Example #1
    def test_algo(self):
        paddle_model, tokenize = self.set_paddle_model()
        img_path = 'tutorials/assets/catdog.png'
        texts = ["a cat"]
        text_tokenized = tokenize(texts)
        algo = it.GAInterpreter(paddle_model, device='cpu')
        R_txt, R_img = algo.interpret(img_path,
                                      texts,
                                      text_tokenized,
                                      crop_to=224,
                                      visual=False)

        result = np.array(
            [R_img.mean(), R_img.std(),
             R_img.min(), R_img.max()])
        desired = np.array([0.00365583, 0.00676875, 0.00013451, 0.03150804],
                           dtype=np.float32)
        assert_arrays_almost_equal(self, result, desired)
Example #2
def zeroshot_classifier(classnames, templates):
    model, preprocess = get_model_and_preprocess()

    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [template.format(classname)
                     for template in templates]  # format with class
            texts = clip.tokenize(texts).to(get_device())  # tokenize
            class_embeddings = model.encode_text(
                texts)  # embed with text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights,
                                       dim=1).to(get_device())
    return zeroshot_weights
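
A minimal sketch of how the weight matrix returned above is typically applied for zero-shot prediction (the helper name `zeroshot_predict` is hypothetical; `images` is assumed to be a batch already preprocessed with CLIP's `preprocess`):

import torch


def zeroshot_predict(model, images, zeroshot_weights):
    # images: preprocessed batch of shape [batch, 3, H, W]
    # zeroshot_weights: the [embed_dim, num_classes] matrix built by zeroshot_classifier
    with torch.no_grad():
        image_features = model.encode_image(images)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        logits = 100.0 * image_features @ zeroshot_weights
    return logits.argmax(dim=-1)  # predicted class index per image
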
Example #3
    def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray':
        """Transform a `np.ndarray` of strings of length `BatchSize` into
        a `np.ndarray` of shape `BatchSize x EmbeddingDimension`.

        :param data: A `np.ndarray` of strings.
        :param args: Additional positional arguments.
        :param kwargs: Additional keyword arguments.
        :return: A `BatchSize x EmbeddingDimension` numpy `ndarray`.
        """
        import clip
        input_torchtensor = clip.tokenize(data)
        if self.on_gpu:
            input_torchtensor = input_torchtensor.cuda()

        with torch.no_grad():
            embedded_data = self.model.encode_text(input_torchtensor)

        embedded_data = embedded_data.cpu().numpy()
        return embedded_data
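
Stripped of the executor wrapper, the same tokenize → encode_text → numpy round trip can be sketched as a standalone script (a minimal sketch assuming the standard `clip` package and the ViT-B/32 checkpoint, whose text encoder returns 512-dimensional embeddings):

import clip
import numpy as np
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("ViT-B/32", device=device)

data = np.array(["a photo of a dog", "a photo of a cat"])
tokens = clip.tokenize(data.tolist()).to(device)  # [2, 77] integer tensor
with torch.no_grad():
    embeddings = model.encode_text(tokens).cpu().numpy()  # shape (2, 512)
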
Example #4
def create(network):
    """Loads the CLIP model."""
    json_path = os.path.join(os.path.dirname(__file__), "imagenet.json")
    with open(json_path, "r") as fp:
        imagenet_labels = json.load(fp)

    with torch.set_grad_enabled(False):
        model, _ = clip.load(network, device="cuda", jit=False)
        model = model.eval()

    prompts = clip.tokenize(
        [f"This is a photo of a {label}" for label in imagenet_labels])

    with torch.no_grad():
        prompts_features = model.encode_text(prompts.cuda()).float()
        prompts_features /= prompts_features.norm(dim=-1, keepdim=True)

    image_mean = [0.48145466, 0.4578275, 0.40821073]
    image_std = [0.26862954, 0.26130258, 0.27577711]

    def call(features):
        # Normalize according to the documentation. Note that the pre-processing
        # will already have the range normalized to [0, 1].
        images_normalized = (features["image"] - image_mean) / image_std
        # Reshape from [batch, h, w, c] -> [batch, c, h, w]
        images_torch = torch.tensor(
            np.transpose(images_normalized, [0, 3, 1, 2]).astype(np.float32))
        with torch.no_grad():
            image_features = model.encode_image(
                images_torch.to("cuda")).float()
            image_features /= image_features.norm(dim=-1, keepdim=True)
            similarities = image_features @ prompts_features.T
            # The 100 (inv temperature) comes from the released code.
            return (100.0 * similarities).softmax(dim=-1).cpu().numpy()

    input_resolution = model.visual.input_resolution
    preprocess_config = (f"resize_small({input_resolution})|"
                         f"central_crop({input_resolution})|"
                         f"value_range(0,1)")
    preprocess_fn = pipeline_builder.get_preprocess_fn(preprocess_config,
                                                       remove_tpu_dtypes=False)
    return call, preprocess_fn
Example #5
    def forward(self, captions):
        special_words = ['<unk>', '<start>', '<end>', '<pad>']
        special_words_enc = [self.word_map[w] for w in special_words]
        text = []
        print(captions)
        for cap in captions:
            cap_words = [
                self.rev_word_map[w] for w in cap[1:-1].tolist()
                if w not in special_words_enc
            ]
            text.append(' '.join(cap_words))
        text = clip.tokenize(text).to(device)
        with torch.no_grad():
            embedding = self.clip_model.encode_text(text)
        embedding = embedding.view(embedding.shape[0], self.embed_dim, -1)
        adaptive_pool = nn.AdaptiveAvgPool1d(captions.shape[1])
        embedding = adaptive_pool(embedding)
        embedding = embedding.permute(0, 2, 1)

        return embedding
Example #6
def getConceptFeatures(concept_path):
    concepts = []
    with open(concept_path, 'r') as f:
        for word in f.readlines():
            concepts.append(word.strip().split("'")[1])

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    with torch.no_grad():
        concept_embs = {}
        for word in concepts:
            word_emb = clip.tokenize(word).to(device)
            word_feature = model.encode_text(word_emb)
            concept_embs[word] = word_feature.cpu().numpy()

        with open(r'/data/linkang/VHL_GNN/utc/concept_clip.pkl', 'wb') as file:
            pickle.dump(concept_embs, file)

    return
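
The loop above issues one encode_text call per concept word. A batched variant (a sketch with an illustrative word list rather than the project's concept file) tokenizes the whole list at once:

import clip
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("ViT-B/32", device=device)
concepts = ["dog", "bicycle", "beach"]  # illustrative, not the real concept list

with torch.no_grad():
    tokens = clip.tokenize(concepts).to(device)          # [num_concepts, 77]
    features = model.encode_text(tokens).cpu().numpy()   # [num_concepts, 512]
concept_embs = dict(zip(concepts, features))
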
Example #7
    def compute_lang_embed(self, nls):
        with torch.no_grad():
            if self.use_clip:
                tokens = clip.tokenize(nls).cuda()
                outputs = self.clip_model.encode_text(tokens)
                lang_embeds = self.lang_fc(outputs)

            else:
                tokens = self.bert_tokenizer.batch_encode_plus(
                    nls, padding="longest", return_tensors="pt")
                outputs = self.bert_model(
                    tokens["input_ids"].cuda(),
                    attention_mask=tokens["attention_mask"].cuda(),
                )
                lang_embeds = torch.mean(outputs.last_hidden_state, dim=1)
                lang_embeds = self.lang_fc(lang_embeds)
            if self.is_triplet_loss:
                lang_embeds = F.normalize(lang_embeds)

        return lang_embeds
Example #8
def compute_clip_loss(img, text, ref_img=None):
    img = clip_transform(img)
    img = torch.nn.functional.interpolate(
        img, (224, 224), mode='bilinear', align_corners=True)
    img_logits = clip_model.encode_image(img)

    tokenized_text = clip.tokenize([text]).to(device).detach().clone()
    text_logits = clip_model.encode_text(tokenized_text)

    loss = 10 * -torch.cosine_similarity(text_logits, img_logits).mean()

    if ref_img is not None:
        ref_img = clip_transform(ref_img)
        ref_img = torch.nn.functional.interpolate(
            ref_img, (224, 224), mode='bilinear', align_corners=True)
        ref_img_logits = clip_model.encode_image(ref_img)

        loss += 10 * -torch.cosine_similarity(ref_img_logits,
                                              img_logits).mean()
        loss /= 2

    return loss
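
For context, a hypothetical driver loop that minimizes this loss with respect to a generator latent; the generator `G` (mapping a [1, 512] latent to an image batch in [0, 1]), the latent size, and the hyperparameters are assumptions, and only `compute_clip_loss` comes from the snippet above:

import torch


def optimize_latent(G, text, device="cuda", steps=200, lr=0.05):
    # `G` is a hypothetical generator; compute_clip_loss (and the clip_model /
    # clip_transform globals it uses) is assumed to be in scope from above.
    z = torch.randn(1, 512, device=device, requires_grad=True)
    optimizer = torch.optim.Adam([z], lr=lr)
    for _ in range(steps):
        loss = compute_clip_loss(G(z), text)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return z.detach()
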
Example #9
    def compute_up_to(self, strings, layer) -> torch.Tensor:
        assert layer in ["text_projection", "full"]

        non_cached = []
        for s in strings:
            if s not in self._cache:
                non_cached.append(s)

        def closure(self, tokens):  # pass self.model
            # taken from model.encode_text
            x = self.token_embedding(tokens).type(
                self.dtype)  # [batch_size, n_ctx, d_model]

            x = x + self.positional_embedding.type(self.dtype)
            x = x.permute(1, 0, 2)  # NLD -> LND
            x = self.transformer(x)
            x = x.permute(1, 0, 2)  # LND -> NLD
            x = self.ln_final(x).type(self.dtype)

            # x.shape = [batch_size, n_ctx, transformer.width]
            # take features from the eot embedding (eot_token is the highest number in each sequence)
            x = x[torch.arange(x.shape[0]), tokens.argmax(dim=-1)]
            if layer == "text_projection":
                return x
            elif layer == "full":
                x = x @ self.text_projection
                return F.normalize(x)
            else:
                assert False

        if len(non_cached) > 0:
            tokens = clip.tokenize(non_cached).to(self.device)
            new_vecs = closure(self.model, tokens)
            for (s, v) in zip(non_cached, new_vecs):
                self._cache[s] = v

        ans = []
        for s in strings:
            ans.append(self._cache[s])

        return torch.stack(ans)
Example #10
def imagine(text,
            model_path,
            lr=.07,
            seed=0,
            num_epochs=200,
            total_plots=20,
            batch_size=16,
            outdir=None,
            stylegan2_dir="stylegan2-ada-pytorch",
            clip_dir="CLIP",
            la=1,
            lb=100,
            verbose=False,
            only_last=False):
    sys.path.insert(1, clip_dir)
    import clip
    perceptor, preprocess = clip.load('ViT-B/32')
    model = Stylegan2Gen(model_path, stylegan2_dir)
    im_shape = perpre(model(gen_random(model.z_dim, model.G.c_dim)))[0].size()
    sideX, sideY, channels = im_shape

    torch.manual_seed(seed)
    lats = Pars(model.z_dim, model.G.c_dim, batch_size).cuda()
    optimizer = torch.optim.Adam(lats.parameters(), lr)

    nom = torchvision.transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073),
        (0.26862954, 0.26130258, 0.27577711))
    tx = clip.tokenize(text)
    t = perceptor.encode_text(tx.cuda()).detach().clone()

    outdir = (text if outdir is None else outdir)
    plot_every = int(num_epochs / total_plots)
    if not os.path.isdir(outdir):
        os.mkdir(outdir)
    for i in trange(num_epochs):
        train(i, outdir, plot_every, model, perceptor, optimizer, t, nom, lats,
              la, lb, verbose, only_last)
    final(outdir, plot_every, model, perceptor, optimizer, t, nom, lats, la,
          lb)
Example #11
def clip_clasify_element(folder):

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, transform = clip.load("ViT-B/32", device=device)

    initial_class_names = ["door", "window", "flooring", "lumber"]
    class_captions = [f"An image depicting a {x}" for x in initial_class_names]
    text_input = clip.tokenize(class_captions).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_input).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)

    fld = folder

    dataset = ImageFolder(root=fld, transform=transform)
    data_batches = DataLoader(dataset, batch_size=len(dataset), shuffle=False)
    image_input, y_true = next(iter(data_batches))
    image_input = image_input.to(device)
    with torch.no_grad():
        image_features = model.encode_image(image_input).float()
    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    text_probs = text_probs.cpu()

    files = sorted(os.listdir(folder))

    for i, (image, label_idx) in enumerate(dataset):
        idx = int(text_probs[i].argmax().item())
        element = initial_class_names[idx]

        probs_data, probs_recom = clip_classify(folder, element, i)[:2]
        print("It's ", element, " ", probs_recom)
Example #12
 def test_title(self, title: str) -> Optional[Any]:
     with torch.no_grad():
         tensor = self.network.encode_text(
             clip.tokenize([title,
                            TEMPLATE.format(title=title)]).to(DEVICE))
     matrix = tensor.cpu().numpy()
     expected_title = matrix[0, :]
     expected_context = matrix[1, :]
     actual_title, actual_context = self.title2vec(title)
     if np.array_equal(expected_title, actual_title) and np.array_equal(
             expected_context, actual_context):
         return None
     else:
         if not np.array_equal(expected_title,
                               actual_title) and not np.array_equal(
                                   expected_context, actual_context):
             return (expected_title, actual_title, expected_context,
                     actual_context)
         elif not np.array_equal(expected_title, actual_title):
             return (expected_title, actual_title, None, None)
         elif not np.array_equal(expected_context, actual_context):
             return (None, None, expected_context, actual_context)
     return None
Example #13
    def process_input(self, image, labels):
        """Creates a probability distribution of image over labels.
        
        Args:
            image<str>: Path of the image to be processed.
            labels<list>: List of labels over which to predict.

        Returns:
            dict: mapping from each label to its predicted probability.
        """
        image = self.preprocess(Image.open(image)).unsqueeze(0).to(self.device)
        text = clip.tokenize(labels).to(self.device)

        with torch.no_grad():
            image_features = self.model.encode_image(image)
            text_features = self.model.encode_text(text)

            logits_per_image, logits_per_text = self.model(image, text)
            probs = logits_per_image.softmax(dim=-1).cpu().numpy()

        probs = probs[0]
        output = {label: float(prob) for prob, label in zip(probs, labels)}

        return output
Example #14
def approach(
    G,
    *,
    num_steps                  = 100,
    w_avg_samples              = 10000,
    initial_learning_rate      = 0.02,
    initial_noise_factor       = 0.02,
    noise_floor                = 0.02,
    psi                        = 0.8,
    noise_ramp_length          = 1.0, # was 0.75
    regularize_noise_weight    = 10000, # was 1e5
    seed                       = 69097,
    noise_opt                  = True,
    ws                         = None,
    text                       = 'a computer generated image',
    device: torch.device
):

    '''
    local_args = dict(locals())
    params = []
    for x in local_args:
        if x != 'G' and x != 'device':
            print(x,':',local_args[x])
            params.append({x:local_args[x]})
    print(json.dumps(params))
    '''

    G = copy.deepcopy(G).eval().requires_grad_(False).to(device)

    lr = initial_learning_rate

    '''
    # Compute w stats.
    logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
    z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
    #w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None)  # [N, L, C]
    w_samples = G.mapping(torch.from_numpy(z_samples).to(device),  None, truncation_psi=0.8)  # [N, L, C]
    w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)       # [N, 1, C]
    w_avg = np.mean(w_samples, axis=0, keepdims=True)      # [1, 1, C]
    w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
    '''

    # derive W from seed
    if ws is None:
        print('Generating w for seed %i' % seed )
        z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
        w_samples = G.mapping(z,  None, truncation_psi=psi)
        w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)
        w_avg = np.mean(w_samples, axis=0, keepdims=True)
    else:
        w_samples = torch.tensor(ws, device=device)
        w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)
        w_avg = np.mean(w_samples, axis=0, keepdims=True)
    #w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
    w_std = 2 # ~9.9 for portraits network. should compute if using median median

    # Setup noise inputs.
    noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name }
    w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
    w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)

    if noise_opt:
        optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)
        print('optimizer: w + noise')
    else:
        optimizer = torch.optim.Adam([w_opt] , betas=(0.9, 0.999), lr=initial_learning_rate)
        print('optimizer: w')

    # Init noise.
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True

    # Load the perceptor
    print('Loading perceptor for text:', text)
    perceptor, preprocess = clip.load('ViT-B/32', jit=True)
    perceptor = perceptor.eval()
    tx = clip.tokenize(text)
    whispers = perceptor.encode_text(tx.cuda()).detach().clone()

    # Descend
    for step in range(num_steps):
        # noise schedule
        t = step / num_steps
        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2

        # floor
        if w_noise_scale < noise_floor:
            w_noise_scale = noise_floor

        # lr schedule is disabled
        '''
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        '''

        ''' for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        '''

        # do G.synthesis
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])
        synth_images = G.synthesis(ws, noise_mode='const')

        #save1
        '''
        synth_images_save = (synth_images + 1) * (255/2)
        synth_images_save = synth_images_save.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
        PIL.Image.fromarray(synth_images_save, 'RGB').save('project/test1.png')
        '''

        nom = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        into = nom(synth_images) # normalize as in CLIP preprocess

        # scale to CLIP input size
        into = torch.nn.functional.interpolate(into, (224,224), mode='bilinear', align_corners=True)

        # CLIP expects [1, 3, 224, 224], so we should be fine
        glimmers = perceptor.encode_image(into)
        away =  -30 * torch.cosine_similarity(whispers, glimmers, dim = -1).mean() # Dunno why 30 works lol

        # noise reg, from og projector
        reg_loss = 0.0
        for v in noise_bufs.values():
            noise = v[None,None,:,:] # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise*torch.roll(noise, shifts=1, dims=3)).mean()**2
                reg_loss += (noise*torch.roll(noise, shifts=1, dims=2)).mean()**2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)

        if noise_opt:
            loss = away + reg_loss * regularize_noise_weight
        else:
            loss = away

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        print(f'step {step+1:>4d}/{num_steps}:  loss {float(loss):<5.2f} ','lr', lr, f'noise scale: {float(w_noise_scale):<5.6f}',f'away: {float(away / (-30)):<5.6f}')

        w_out[step] = w_opt.detach()[0]

        # Normalize noise.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    return w_out.repeat([1, G.mapping.num_ws, 1])
Example #15
def tokenize_text(search_query):
    search_query = clip.tokenize(search_query).to(device)
    return search_query
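
For reference, a short sketch of what the tokenizer returns (assuming the standard OpenAI `clip` package): an integer tensor with one row per input string, padded to CLIP's 77-token context length.

import clip

tokens = clip.tokenize(["a photo of a dog", "a photo of a cat"])
print(tokens.shape)  # torch.Size([2, 77])
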
Example #16
    def process(txt, num):

        global params_start
        params, image_f = fft_image([1, 3, *a.size], resume='init.pt')
        image_f = to_valid_rgb(image_f)
        optimizer = torch.optim.Adam(params, a.lrate)

        if a.verbose is True: print(' ref text: ', txt)
        if a.translate:
            translator = Translator()
            txt = translator.translate(txt, dest='en').text
            if a.verbose is True: print(' translated to:', txt)
        tx = clip.tokenize(txt).cuda()
        txt_enc = model_clip.encode_text(tx).detach().clone()

        out_name = '%03d-%s' % (num + 1, txt_clean(txt))
        out_name += '-%s' % a.model if 'RN' in a.model.upper() else ''
        tempdir = os.path.join(workdir, out_name)
        os.makedirs(tempdir, exist_ok=True)

        pbar = ProgressBar(a.steps // a.fstep)
        for i in range(a.steps):
            loss = 0

            noise = a.noise * torch.randn(1, 1, *params[0].shape[2:4],
                                          1).cuda() if a.noise > 0 else None
            img_out = image_f(noise)

            imgs_sliced = slice_imgs([img_out],
                                     a.samples,
                                     a.modsize,
                                     norm_in,
                                     a.overscan,
                                     micro=None)
            out_enc = model_clip.encode_image(imgs_sliced[-1])
            loss -= torch.cosine_similarity(txt_enc, out_enc, dim=-1).mean()
            if a.in_txt0 is not None:  # subtract text
                loss += torch.cosine_similarity(txt_enc0, out_enc,
                                                dim=-1).mean()
            del img_out, imgs_sliced, out_enc
            torch.cuda.empty_cache()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % a.fstep == 0:
                with torch.no_grad():
                    img = image_f(contrast=a.contrast).cpu().numpy()[0]
                checkout(img,
                         os.path.join(tempdir, '%04d.jpg' % (i // a.fstep)),
                         verbose=a.verbose)
                pbar.upd()
                del img

        if a.keep == 'all':
            params_start = ema(params_start, params[0].detach(), num + 1)
            torch.save(params_start, 'init.pt')
        elif a.keep == 'last':
            torch.save((params_start + params[0].detach()) / 2, 'init.pt')

        torch.save(params[0], '%s.pt' % os.path.join(workdir, out_name))
        shutil.copy(
            img_list(tempdir)[-1],
            os.path.join(workdir, '%s-%d.jpg' % (out_name, a.steps)))
        os.system('ffmpeg -v warning -y -i "%s" "%s.mp4"' %
                  (os.path.join(tempdir, '%04d.jpg'),
                   os.path.join(workdir, out_name)))
Example #17
def main():
    a = get_args()

    # Load CLIP models
    model_clip, _ = clip.load(a.model)
    if a.verbose is True: print(' using model', a.model)
    xmem = {'RN50': 0.5, 'RN50x4': 0.16, 'RN101': 0.33}
    if 'RN' in a.model:
        a.samples = int(a.samples * xmem[a.model])
    workdir = os.path.join(a.out_dir, basename(a.in_txt))
    workdir += '-%s' % a.model if 'RN' in a.model.upper() else ''
    os.makedirs(workdir, exist_ok=True)

    norm_in = torchvision.transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073),
        (0.26862954, 0.26130258, 0.27577711))

    if a.in_txt0 is not None:
        if a.verbose is True: print(' subtract text:', basename(a.in_txt0))
        if a.translate:
            translator = Translator()
            a.in_txt0 = translator.translate(a.in_txt0, dest='en').text
            if a.verbose is True: print(' translated to:', a.in_txt0)
        tx0 = clip.tokenize(a.in_txt0).cuda()
        txt_enc0 = model_clip.encode_text(tx0).detach().clone()

    # make init
    global params_start
    params_shape = [1, 3, a.size[0], a.size[1] // 2 + 1, 2]
    params_start = torch.randn(*params_shape).cuda()  # random init

    if a.resume is not None and os.path.isfile(a.resume):
        if a.verbose is True: print(' resuming from', a.resume)
        params, _ = fft_image([1, 3, *a.size], resume=a.resume)
        params_start = ema(params_start, params[0].detach(), 1)
    else:
        a.resume = 'init.pt'

    shutil.copy(a.resume,
                os.path.join(workdir, '000-%s.pt' % basename(a.resume)))
    torch.save(params_start, 'init.pt')  # final init

    def process(txt, num):

        global params_start
        params, image_f = fft_image([1, 3, *a.size], resume='init.pt')
        image_f = to_valid_rgb(image_f)
        optimizer = torch.optim.Adam(params, a.lrate)

        if a.verbose is True: print(' ref text: ', txt)
        if a.translate:
            translator = Translator()
            txt = translator.translate(txt, dest='en').text
            if a.verbose is True: print(' translated to:', txt)
        tx = clip.tokenize(txt).cuda()
        txt_enc = model_clip.encode_text(tx).detach().clone()

        out_name = '%03d-%s' % (num + 1, txt_clean(txt))
        out_name += '-%s' % a.model if 'RN' in a.model.upper() else ''
        tempdir = os.path.join(workdir, out_name)
        os.makedirs(tempdir, exist_ok=True)

        pbar = ProgressBar(a.steps // a.fstep)
        for i in range(a.steps):
            loss = 0

            noise = a.noise * torch.randn(1, 1, *params[0].shape[2:4],
                                          1).cuda() if a.noise > 0 else None
            img_out = image_f(noise)

            imgs_sliced = slice_imgs([img_out],
                                     a.samples,
                                     a.modsize,
                                     norm_in,
                                     a.overscan,
                                     micro=None)
            out_enc = model_clip.encode_image(imgs_sliced[-1])
            loss -= torch.cosine_similarity(txt_enc, out_enc, dim=-1).mean()
            if a.in_txt0 is not None:  # subtract text
                loss += torch.cosine_similarity(txt_enc0, out_enc,
                                                dim=-1).mean()
            del img_out, imgs_sliced, out_enc
            torch.cuda.empty_cache()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % a.fstep == 0:
                with torch.no_grad():
                    img = image_f(contrast=a.contrast).cpu().numpy()[0]
                checkout(img,
                         os.path.join(tempdir, '%04d.jpg' % (i // a.fstep)),
                         verbose=a.verbose)
                pbar.upd()
                del img

        if a.keep == 'all':
            params_start = ema(params_start, params[0].detach(), num + 1)
            torch.save(params_start, 'init.pt')
        elif a.keep == 'last':
            torch.save((params_start + params[0].detach()) / 2, 'init.pt')

        torch.save(params[0], '%s.pt' % os.path.join(workdir, out_name))
        shutil.copy(
            img_list(tempdir)[-1],
            os.path.join(workdir, '%s-%d.jpg' % (out_name, a.steps)))
        os.system('ffmpeg -v warning -y -i "%s" "%s.mp4"' %
                  (os.path.join(tempdir, '%04d.jpg'),
                   os.path.join(workdir, out_name)))

    with open(a.in_txt, 'r', encoding="utf-8") as f:
        texts = f.readlines()
        texts = [
            tt.strip() for tt in texts if len(tt.strip()) > 0 and tt[0] != '#'
        ]
    if a.verbose is True:
        print(' total lines:', len(texts))
        print(' samples:', a.samples)

    for i, txt in enumerate(texts):
        process(txt, i)

    vsteps = int(a.length * 25 / len(texts))  # 25 fps
    tempdir = os.path.join(workdir, '_final')
    os.makedirs(tempdir, exist_ok=True)

    def read_pt(file):
        return torch.load(file).cuda()

    if a.verbose is True: print(' rendering complete piece')
    ptfiles = file_list(workdir, 'pt')
    pbar = ProgressBar(vsteps * len(ptfiles))
    for px in range(len(ptfiles)):
        params1 = read_pt(ptfiles[px])
        params2 = read_pt(ptfiles[(px + 1) % len(ptfiles)])

        params, image_f = fft_image([1, 3, *a.size], resume=params1)
        image_f = to_valid_rgb(image_f)

        for i in range(vsteps):
            with torch.no_grad():
                img = image_f(
                    (params2 - params1) *
                    math.sin(1.5708 * i / vsteps)**2)[0].permute(1, 2, 0)
                img = torch.clip(img * 255, 0,
                                 255).cpu().numpy().astype(np.uint8)
            imsave(os.path.join(tempdir, '%05d.jpg' % (px * vsteps + i)), img)
            if a.verbose is True: cvshow(img)
            pbar.upd()

    os.system('ffmpeg -v warning -y -i "%s" "%s.mp4"' %
              (os.path.join(tempdir, '%05d.jpg'),
               os.path.join(a.out_dir, basename(a.in_txt))))
    if a.keep is True: os.remove('init.pt')
Example #18
    def process(txt, num):

        sd = 0.01
        if a.keep > 0: sd = a.keep + (1-a.keep) * sd
        params, image_f = fft_image([1, 3, *a.size], resume='init.pt', sd=sd, decay_power=a.decay)
        image_f = to_valid_rgb(image_f, colors = a.colors)

        if a.prog is True:
            lr1 = a.lrate * 2
            lr0 = a.lrate * 0.1
        else:
            lr0 = a.lrate
        optimizer = torch.optim.Adam(params, lr0)
    
        if a.verbose is True: print(' ref text: ', txt)
        if a.translate:
            translator = Translator()
            txt = translator.translate(txt, dest='en').text
            if a.verbose is True: print(' translated to:', txt)
        if a.multilang is True:
            model_lang = SentenceTransformer('clip-ViT-B-32-multilingual-v1').cuda()
            txt_enc = model_lang.encode([txt], convert_to_tensor=True, show_progress_bar=False).detach().clone()
            del model_lang
        else:
            txt_enc = model_clip.encode_text(clip.tokenize(txt).cuda()).detach().clone()
        if a.notext > 0:
            txt_plot = torch.from_numpy(plot_text(txt, a.modsize)/255.).unsqueeze(0).permute(0,3,1,2).cuda()
            txt_plot_enc = model_clip.encode_image(txt_plot).detach().clone()
        else: txt_plot_enc = None

        out_name = '%03d-%s' % (num+1, txt_clean(txt))
        out_name += '-%s' % a.model if 'RN' in a.model.upper() else ''
        tempdir = os.path.join(workdir, out_name)
        os.makedirs(tempdir, exist_ok=True)
        
        pbar = ProgressBar(a.steps // a.fstep)
        for i in range(a.steps):
            loss = 0

            noise = a.noise * torch.randn(1, 1, *params[0].shape[2:4], 1).cuda() if a.noise > 0 else None
            img_out = image_f(noise)
            
            if a.sharp != 0:
                lx = torch.mean(torch.abs(img_out[0,:,:,1:] - img_out[0,:,:,:-1]))
                ly = torch.mean(torch.abs(img_out[0,:,1:,:] - img_out[0,:,:-1,:]))
                loss -= a.sharp * (ly+lx)

            imgs_sliced = slice_imgs([img_out], a.samples, a.modsize, trform_f, a.align, micro=1.)
            out_enc = model_clip.encode_image(imgs_sliced[-1])
            loss -= torch.cosine_similarity(txt_enc, out_enc, dim=-1).mean()
            if a.notext > 0:
                loss += a.notext * torch.cosine_similarity(txt_plot_enc, out_enc, dim=-1).mean()
            if a.diverse != 0:
                imgs_sliced = slice_imgs([image_f(noise)], a.samples, a.modsize, trform_f, a.align, micro=1.)
                out_enc2 = model_clip.encode_image(imgs_sliced[-1])
                loss += a.diverse * torch.cosine_similarity(out_enc, out_enc2, dim=-1).mean()
                del out_enc2; torch.cuda.empty_cache()
            if a.expand > 0:
                global prev_enc
                if i > 0:
                    loss += a.expand * torch.cosine_similarity(out_enc, prev_enc, dim=-1).mean()
                prev_enc = out_enc.detach().clone()
            if a.in_txt0 is not None: # subtract text
                loss += torch.cosine_similarity(txt_enc0, out_enc, dim=-1).mean()
            del img_out, imgs_sliced, out_enc; torch.cuda.empty_cache()

            if a.prog is True:
                lr_cur = lr0 + (i / a.steps) * (lr1 - lr0)
                for g in optimizer.param_groups: 
                    g['lr'] = lr_cur
        
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % a.fstep == 0:
                with torch.no_grad():
                    img = image_f(contrast=a.contrast).cpu().numpy()[0]
                if a.sharp != 0:
                    img = img **1.3 # empirical tone mapping
                checkout(img, os.path.join(tempdir, '%04d.jpg' % (i // a.fstep)), verbose=a.verbose)
                pbar.upd()
                del img

        if a.keep > 0:
            global params_start, params_ema
            params_ema = ema(params_ema, params[0].detach().clone(), num+1)
            torch.save((1-a.keep) * params_start + a.keep * params_ema, 'init.pt')
        
        torch.save(params[0], '%s.pt' % os.path.join(workdir, out_name))
        shutil.copy(img_list(tempdir)[-1], os.path.join(workdir, '%s-%d.jpg' % (out_name, a.steps)))
        os.system('ffmpeg -v warning -y -i "%s" "%s.mp4"' % (os.path.join(tempdir, '%04d.jpg'), os.path.join(workdir, out_name)))
Example #19
    def embedding(self, split_dir, model):
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
        else:
            if split_dir == 'test':
                split_dir = 'valid'
            # Build and load the generator
            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = G_NET()
            netG.apply(weights_init)
            if cfg.GPU_ID != -1:
                netG.cuda()
            netG.eval()
            #
            model_dir = cfg.TRAIN.NET_G
            state_dict = \
                torch.load(model_dir, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', model_dir)

            image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
            img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder',
                                                       'image_encoder')
            print(img_encoder_path)
            print('Load image encoder from:', img_encoder_path)
            state_dict = \
                torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
            image_encoder.load_state_dict(state_dict)
            if cfg.GPU_ID != -1:
                image_encoder = image_encoder.cuda()
            image_encoder.eval()

            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder = RNN_ENCODER(self.n_words,
                                       nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = \
                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            if cfg.GPU_ID != -1:
                text_encoder = text_encoder.cuda()
            text_encoder.eval()

            batch_size = self.batch_size
            nz = cfg.GAN.Z_DIM

            with torch.no_grad():
                noise = Variable(torch.FloatTensor(batch_size, nz))
                if cfg.GPU_ID != -1:
                    noise = noise.cuda()

            # the path to save generated images
            save_dir = model_dir[:model_dir.rfind('.pth')]

            cnt = 0

            # new
            if cfg.TRAIN.CLIP_SENTENCODER:
                print("Use CLIP SentEncoder for sampling")
            img_features = dict()
            txt_features = dict()

            with torch.no_grad():
                for _ in range(1):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
                    for step, data in enumerate(self.data_loader, 0):
                        cnt += batch_size
                        if step % 100 == 0:
                            print('step: ', step)

                        imgs, captions, cap_lens, class_ids, keys, texts = prepare_data(
                            data)

                        hidden = text_encoder.init_hidden(batch_size)
                        # words_embs: batch_size x nef x seq_len
                        # sent_emb: batch_size x nef
                        words_embs, sent_emb = text_encoder(
                            captions, cap_lens, hidden)
                        words_embs, sent_emb = words_embs.detach(
                        ), sent_emb.detach()
                        mask = (captions == 0)
                        num_words = words_embs.size(2)
                        if mask.size(1) > num_words:
                            mask = mask[:, :num_words]

                        if cfg.TRAIN.CLIP_SENTENCODER:

                            # randomly select one paragraph for each training example
                            sents = []
                            for idx in range(len(texts)):
                                sents_per_image = texts[idx].split(
                                    '\n')  # new 3/11
                                if len(sents_per_image) > 1:
                                    sent_ix = np.random.randint(
                                        0,
                                        len(sents_per_image) - 1)
                                else:
                                    sent_ix = 0
                                sents.append(sents_per_image[sent_ix])
                            # print('sents: ', sents)

                            sent = clip.tokenize(sents)  # .to(device)

                            # load clip
                            #model = torch.jit.load("model.pt").cuda().eval()
                            sent_input = sent
                            if cfg.GPU_ID != -1:
                                sent_input = sent.cuda()
                            # print("text input", sent_input)
                            sent_emb_clip = model.encode_text(
                                sent_input).float()
                            if CLIP:
                                sent_emb = sent_emb_clip
                        #######################################################
                        # (2) Generate fake images
                        ######################################################
                        noise.data.normal_(0, 1)
                        fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs,
                                                  mask)
                        if CLIP:
                            images = []
                            for j in range(fake_imgs[-1].shape[0]):
                                image = fake_imgs[-1][j].cpu().clone()
                                image = image.squeeze(0)
                                unloader = transforms.ToPILImage()
                                image = unloader(image)

                                image = preprocess(
                                    image.convert("RGB"))  # 256*256 -> 224*224
                                images.append(image)

                            image_mean = torch.tensor(
                                [0.48145466, 0.4578275, 0.40821073]).cuda()
                            image_std = torch.tensor(
                                [0.26862954, 0.26130258, 0.27577711]).cuda()

                            image_input = torch.tensor(np.stack(images)).cuda()
                            image_input -= image_mean[:, None, None]
                            image_input /= image_std[:, None, None]
                            cnn_codes = model.encode_image(image_input).float()
                        else:
                            region_features, cnn_codes = image_encoder(
                                fake_imgs[-1])
                        for j in range(batch_size):
                            cnn_code = cnn_codes[j]

                            temp = keys[j].replace('b', '').replace("'", '')
                            img_features[temp] = cnn_code.cpu().numpy()
                            txt_features[temp] = sent_emb[j].cpu().numpy()
            with open(save_dir + ".pkl", 'wb') as f:
                pickle.dump(img_features, f)
            with open(save_dir + "_text.pkl", 'wb') as f:
                pickle.dump(txt_features, f)
Example #20
    def train(self, model):
        text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models(
        )  #load encoder
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        # gen_iterations = start_epoch * self.num_batches

        if cfg.TRAIN.CLIP_SENTENCODER:
            print("CLIP Sentence Encoder: True")

        if cfg.TRAIN.CLIP_LOSS:
            print("CLIP Loss: True")

        if cfg.TRAIN.EXTRA_LOSS:
            print("Extra DAMSM Loss in G: True")
            print("DAMSM Weight: ", cfg.TRAIN.WEIGHT_DAMSM_LOSS)

        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                # reset requires_grad to be trainable for all Ds
                # self.set_requires_grad_value(netsD, True)

                ######################################################
                # (1) Prepare training data and Compute text embeddings
                ######################################################
                data = next(data_iter)

                # imgs, captions, cap_lens, class_ids, keys = prepare_data(data) #new sents:, sents
                # new: return raw texts
                imgs, captions, cap_lens, class_ids, keys, texts = prepare_data(
                    data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                # new: rename
                words_embs_damsm, sent_emb_damsm = text_encoder(
                    captions, cap_lens, hidden)
                #print('captions shape from trainer: ', captions.shape) torch.Size([12, 18])
                #print('sentence emb size: ', sent_emb.shape) torch.Size([12, 256])
                words_embs_damsm, sent_emb_damsm = words_embs_damsm.detach(
                ), sent_emb_damsm.detach()
                #print('sentence emb size after detach: ', sent_emb[0]) torch.Size([12, 256])
                mask = (captions == 0)
                num_words = words_embs_damsm.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                # new: use clip sentence encoder
                if cfg.TRAIN.CLIP_SENTENCODER or cfg.TRAIN.CLIP_LOSS:
                    sents = []
                    # randomly select one paragraph for each training example
                    for idx in range(len(texts)):
                        sents_per_image = texts[idx].split(
                            '\n')  #new: '\n' rather than '.'
                        if len(sents_per_image) > 1:
                            sent_ix = np.random.randint(
                                0,
                                len(sents_per_image) - 1)
                        else:
                            sent_ix = 0
                        sents.append(sents_per_image[sent_ix])
                    #print('sents: ', sents)

                    sent = clip.tokenize(sents)  #.to(device)

                    # load clip
                    #model = torch.jit.load("model.pt").cuda().eval()    # ViT-B/32
                    sent_input = sent.cuda()

                    with torch.no_grad():
                        sent_emb_clip = model.encode_text(sent_input).float()
                        if cfg.TRAIN.CLIP_SENTENCODER:
                            sent_emb = sent_emb_clip
                        else:
                            sent_emb = sent_emb_damsm
                else:
                    sent_emb_clip = 0
                    sent_emb = sent_emb_damsm

                words_embs = words_embs_damsm

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs,
                                                mask)

                #######################################################
                # (3) Update D network
                ######################################################
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels,
                                              fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, errD.item())

                #######################################################
                # (4) Update G network: maximize log(D(G(z)))
                ######################################################
                # compute total loss for training G
                step += 1
                gen_iterations += 1

                # do not need to compute gradient for Ds
                # self.set_requires_grad_value(netsD, False)
                netG.zero_grad()

                # new: pass clip model and sent_emb_damsm for CLIP_LOSS = True
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                                        words_embs, sent_emb, match_labels, cap_lens, class_ids, model, sent_emb_damsm, sent_emb_clip)

                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                # save images
                if gen_iterations % 1000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG,
                                          fixed_noise,
                                          sent_emb,
                                          words_embs,
                                          mask,
                                          image_encoder,
                                          captions,
                                          cap_lens,
                                          epoch,
                                          name='average')
                    load_params(netG, backup_para)
                    #
                    # self.save_img_results(netG, fixed_noise, sent_emb,
                    #                       words_embs, mask, image_encoder,
                    #                       captions, cap_lens,
                    #                       epoch, name='current')
            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total.item(),
                   errG_total.item(), end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0 or epoch % 10 == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)
Example #21
    def sampling(self, split_dir, model):
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
        else:
            if split_dir == 'test':
                split_dir = 'valid'
            # Build and load the generator
            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = G_NET()
            netG.apply(weights_init)
            if cfg.GPU_ID != -1:
                netG.cuda()
            netG.eval()
            #
            text_encoder = RNN_ENCODER(self.n_words,
                                       nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = \
                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            if cfg.GPU_ID != -1:
                text_encoder = text_encoder.cuda()
            text_encoder.eval()

            batch_size = self.batch_size
            nz = cfg.GAN.Z_DIM

            with torch.no_grad():
                noise = Variable(torch.FloatTensor(batch_size, nz))
                if cfg.GPU_ID != -1:
                    noise = noise.cuda()

            model_dir = cfg.TRAIN.NET_G
            state_dict = \
                torch.load(model_dir, map_location=lambda storage, loc: storage)
            # state_dict = torch.load(cfg.TRAIN.NET_G)
            netG.load_state_dict(state_dict)
            print('Load G from: ', model_dir)

            # the path to save generated images
            s_tmp = model_dir[:model_dir.rfind('.pth')]
            save_dir = '%s/%s' % (s_tmp, split_dir)
            mkdir_p(save_dir)

            cnt = 0

            #new
            if cfg.TRAIN.CLIP_SENTENCODER:
                print("Use CLIP SentEncoder for sampling")

            for _ in range(1):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
                for step, data in enumerate(self.data_loader, 0):
                    cnt += batch_size
                    if step % 100 == 0:
                        print('step: ', step)
                    # if step > 50:
                    #     break

                    #imgs, captions, cap_lens, class_ids, keys = prepare_data(data)
                    #new
                    imgs, captions, cap_lens, class_ids, keys, texts = prepare_data(
                        data)

                    hidden = text_encoder.init_hidden(batch_size)
                    # words_embs: batch_size x nef x seq_len
                    # sent_emb: batch_size x nef
                    words_embs, sent_emb = text_encoder(
                        captions, cap_lens, hidden)
                    words_embs, sent_emb = words_embs.detach(
                    ), sent_emb.detach()
                    mask = (captions == 0)
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]

                    # new
                    if cfg.TRAIN.CLIP_SENTENCODER:

                        # randomly select one paragraph for each training example
                        sents = []
                        for idx in range(len(texts)):
                            sents_per_image = texts[idx].split(
                                '\n')  # new 3/11
                            if len(sents_per_image) > 1:
                                sent_ix = np.random.randint(
                                    0,
                                    len(sents_per_image) - 1)
                            else:
                                sent_ix = 0
                            sents.append(sents_per_image[sent_ix])
                            with open('%s/%s' % (save_dir, 'eval_sents.txt'),
                                      'a+') as f:
                                f.write(sents_per_image[sent_ix] + '\n')
                        # print('sents: ', sents)

                        sent = clip.tokenize(sents)  # .to(device)

                        # load clip
                        #model = torch.jit.load("model.pt").cuda().eval()
                        sent_input = sent
                        if cfg.GPU_ID != -1:
                            sent_input = sent.cuda()
                        # print("text input", sent_input)
                        with torch.no_grad():
                            sent_emb = model.encode_text(sent_input).float()

                    #######################################################
                    # (2) Generate fake images
                    ######################################################
                    noise.data.normal_(0, 1)
                    fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs,
                                              mask)
                    for j in range(batch_size):
                        s_tmp = '%s/fake/%s' % (save_dir, keys[j])
                        folder = s_tmp[:s_tmp.rfind('/')]
                        if not os.path.isdir(folder):
                            print('Make a new folder: ', folder)
                            mkdir_p(folder)
                            print('Make a new folder: ', f'{save_dir}/real')
                            mkdir_p(f'{save_dir}/real')
                            print('Make a new folder: ', f'{save_dir}/text')
                            mkdir_p(f'{save_dir}/text')
                        k = -1
                        # for k in range(len(fake_imgs)):
                        im = fake_imgs[k][j].data.cpu().numpy()
                        # [-1, 1] --> [0, 255]
                        im = (im + 1.0) * 127.5
                        im = im.astype(np.uint8)
                        im = np.transpose(im, (1, 2, 0))
                        im = Image.fromarray(im)
                        fullpath = '%s_s%d.png' % (s_tmp, k)
                        im.save(fullpath)
                        temp = keys[j].replace('b', '').replace("'", '')
                        shutil.copy(f"../data/Face/images/{temp}.jpg",
                                    f"{save_dir}/real/")
                        shutil.copy(f"../data/Face/text/{temp}.txt",
                                    f"{save_dir}/text/")
Example #22
def encode_text(network, titles):
    text = clip.tokenize(titles).to(DEVICE)
    with torch.no_grad():
        tensor = network.encode_text(text)
    return tensor.cpu().numpy()
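
A small usage sketch for this helper: comparing two titles by cosine similarity. The titles and the ViT-B/32 choice are illustrative, and `DEVICE` mirrors the constant the snippet relies on (assumed to live in the same module as encode_text):

import clip
import numpy as np
import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
network, _ = clip.load("ViT-B/32", device=DEVICE)

a, b = encode_text(network, ["stainless steel kettle", "electric tea kettle"])
cosine = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
print(cosine)
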
Example #23
def main(args):
    text_inputs = torch.cat([clip.tokenize(args.description)]).cuda()
    os.makedirs(args.results_dir, exist_ok=True)

    g_ema = Generator(args.size, 512, 8)
    g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
    g_ema.eval()
    g_ema = g_ema.cuda()
    mean_latent = g_ema.mean_latent(4096)

    if args.latent_path:
        latent_code_init = torch.load(args.latent_path).cuda()
    elif args.mode == "edit":
        latent_code_init_not_trunc = torch.randn(1, 512).cuda()
        with torch.no_grad():
            _, latent_code_init = g_ema([latent_code_init_not_trunc],
                                        return_latents=True,
                                        truncation=args.truncation,
                                        truncation_latent=mean_latent)
    else:
        latent_code_init = mean_latent.detach().clone().repeat(1, 18, 1)

    latent = latent_code_init.detach().clone()
    latent.requires_grad = True

    clip_loss = CLIPLoss()

    optimizer = optim.Adam([latent], lr=args.lr)

    pbar = tqdm(range(args.step))

    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr)
        optimizer.param_groups[0]["lr"] = lr

        img_gen, _ = g_ema([latent],
                           input_is_latent=True,
                           randomize_noise=False)

        c_loss = clip_loss(img_gen, text_inputs)

        if args.mode == "edit":
            #l2_loss = ((latent_code_init - latent) ** 2).sum()
            #loss = c_loss + args.l2_lambda * l2_loss
            loss = c_loss
        else:
            loss = c_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_description((f"loss: {loss.item():.4f};"))
        if args.save_intermediate_image_every > 0 and i % args.save_intermediate_image_every == 0:
            with torch.no_grad():
                img_gen, _ = g_ema([latent],
                                   input_is_latent=True,
                                   randomize_noise=False)

            torchvision.utils.save_image(img_gen,
                                         os.path.join(args.results_dir,
                                                      f"{str(i).zfill(5)}.png"),
                                         normalize=True,
                                         range=(-1, 1))

    if args.mode == "edit":
        with torch.no_grad():
            img_orig, _ = g_ema([latent_code_init],
                                input_is_latent=True,
                                randomize_noise=False)

        final_result = torch.cat([img_orig, img_gen])
    else:
        final_result = img_gen

    return final_result
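
The loop above relies on a get_lr schedule that is not shown; a sketch of the cosine ramp-down with linear warm-up commonly used for StyleCLIP-style latent optimization (the rampup/rampdown fractions are assumptions):

import math

def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    # t in [0, 1): fraction of the optimization completed.
    lr_ramp = min(1.0, (1.0 - t) / rampdown)           # cosine ramp-down near the end
    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
    lr_ramp = lr_ramp * min(1.0, t / rampup)           # linear warm-up at the start
    return initial_lr * lr_ramp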
예제 #24
0
    def predict_classes(self,
                        image_path,
                        labels=None,
                        top=15,
                        print_results=True):
        """
        Print predicted classes for a given image using OpenAI CLIP.

        Parameters
        ----------
        image_path : str
            The file location

        labels: list[str]
            list containing the possible labels

        top: int
            Number of labels to select

        print_results: boolean
            Print the top results

        """
        if not labels:
            labels = self.imagenet_labels

        # Transform image to PIL
        #img = cv2.imread(image_path)
        img = self.read_image(image_name=image_path,
                              ACCESS_ID=self.ACCESS_ID,
                              ACCESS_KEY=self.ACCESS_KEY,
                              bucket='carlo-computer-vision-project')
        img = np.array(img)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        im_pil = Image.fromarray(img)

        # Prepare the inputs
        image_input = self.preprocess(im_pil).unsqueeze(0).to(self.device)
        text_inputs = torch.cat([
            clip.tokenize(f"a photo of a {c}") for c in labels
        ]).to(self.device)

        # Calculate features
        with torch.no_grad():
            image_features = self.model.encode_image(image_input)
            text_features = self.model.encode_text(text_inputs)

        # Pick the top most similar labels for the image
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        similarity = 100.0 * image_features @ text_features.T  # .softmax(dim=-1)
        values, indices = similarity[0].topk(min(top, similarity[0].shape[0]))

        prediction_set = set()

        # Print predictions
        if print_results:
            print("Top predictions: \n")
        for idx, (value, index) in enumerate(zip(values, indices)):
            prediction_set.add(labels[index].title())
            if print_results:
                print("{:02d}. {} - Score: {:.2f}".format(
                    idx + 1, labels[index].title(), value.item()))
        return prediction_set
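
A hedged usage sketch for predict_classes; the wrapper class name (ClipClassifier) and its constructor arguments are hypothetical and only illustrate how the method might be called:

# ClipClassifier is assumed to hold self.model, self.preprocess, self.device,
# self.imagenet_labels and the S3 credentials used by read_image.
classifier = ClipClassifier(ACCESS_ID="...", ACCESS_KEY="...")

predictions = classifier.predict_classes(
    image_path="samples/dog.jpg",        # key inside the configured bucket
    labels=["dog", "cat", "horse"],      # candidate labels; defaults to ImageNet
    top=3,
    print_results=True)
print(predictions)  # a set of the top label names, title-cased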
예제 #25
0
def main(args):
    ensure_checkpoint_exists(args.ckpt)
    text_inputs = torch.cat([clip.tokenize(args.description)]).cuda()
    os.makedirs(args.results_dir, exist_ok=True)

    g_ema = Generator(args.stylegan_size, 512, 8)
    g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
    g_ema.eval()
    g_ema = g_ema.cuda()
    mean_latent = g_ema.mean_latent(4096)

    if args.latent_path:
        latent_code_init = torch.load(args.latent_path).cuda()
    elif args.mode == "edit":
        latent_code_init_not_trunc = torch.randn(1, 512).cuda()
        with torch.no_grad():
            _, latent_code_init, _ = g_ema([latent_code_init_not_trunc],
                                           return_latents=True,
                                           truncation=args.truncation,
                                           truncation_latent=mean_latent)
    else:
        latent_code_init = mean_latent.detach().clone().repeat(1, 18, 1)

    with torch.no_grad():
        img_orig, _ = g_ema([latent_code_init],
                            input_is_latent=True,
                            randomize_noise=False)

    if args.work_in_stylespace:
        with torch.no_grad():
            _, _, latent_code_init = g_ema([latent_code_init],
                                           input_is_latent=True,
                                           return_latents=True)
        latent = [s.detach().clone() for s in latent_code_init]
        for c, s in enumerate(latent):
            if c in STYLESPACE_INDICES_WITHOUT_TORGB:
                s.requires_grad = True
    else:
        latent = latent_code_init.detach().clone()
        latent.requires_grad = True

    clip_loss = CLIPLoss(args)
    id_loss = IDLoss(args)

    if args.work_in_stylespace:
        optimizer = optim.Adam(latent, lr=args.lr)
    else:
        optimizer = optim.Adam([latent], lr=args.lr)

    pbar = tqdm(range(args.step))

    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr)
        optimizer.param_groups[0]["lr"] = lr

        img_gen, _ = g_ema([latent],
                           input_is_latent=True,
                           randomize_noise=False,
                           input_is_stylespace=args.work_in_stylespace)

        c_loss = clip_loss(img_gen, text_inputs)

        if args.id_lambda > 0:
            i_loss = id_loss(img_gen, img_orig)[0]
        else:
            i_loss = 0

        if args.mode == "edit":
            if args.work_in_stylespace:
                l2_loss = sum([((latent_code_init[c] - latent[c])**2).sum()
                               for c in range(len(latent_code_init))])
            else:
                l2_loss = ((latent_code_init - latent)**2).sum()
            loss = c_loss + args.l2_lambda * l2_loss + args.id_lambda * i_loss
        else:
            loss = c_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_description((f"loss: {loss.item():.4f};"))
        if args.save_intermediate_image_every > 0 and i % args.save_intermediate_image_every == 0:
            with torch.no_grad():
                img_gen, _ = g_ema([latent],
                                   input_is_latent=True,
                                   randomize_noise=False,
                                   input_is_stylespace=args.work_in_stylespace)

            torchvision.utils.save_image(img_gen,
                                         os.path.join(args.results_dir,
                                                      f"{str(i).zfill(5)}.jpg"),
                                         normalize=True,
                                         range=(-1, 1))

    if args.mode == "edit":
        final_result = torch.cat([img_orig, img_gen])
    else:
        final_result = img_gen

    return final_result
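
Both StyleCLIP-style examples above depend on a CLIPLoss module that is not shown (example #25 constructs it from an args object carrying the StyleGAN resolution). A minimal sketch of the usual formulation, which pools the generated image down to CLIP's 224x224 input and returns one minus the scaled image-text similarity; the default resolution of 1024 is an assumption:

import clip
import torch

class CLIPLoss(torch.nn.Module):
    def __init__(self, stylegan_size=1024):
        super().__init__()
        self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
        # Upsampling by 7 and average-pooling with kernel stylegan_size // 32
        # maps a stylegan_size x stylegan_size render to 224 x 224.
        self.upsample = torch.nn.Upsample(scale_factor=7)
        self.avg_pool = torch.nn.AvgPool2d(kernel_size=stylegan_size // 32)

    def forward(self, image, text):
        image = self.avg_pool(self.upsample(image))
        similarity = 1 - self.model(image, text)[0] / 100
        return similarity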
예제 #26
0
import clip
import torch
from data import PokemonDataset
from tqdm import tqdm
import numpy as np
import pandas as pd

from collections import defaultdict

save_preds = True
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.0005, momentum=0.5)

dataset = PokemonDataset()
classes = list(dataset.get_classes())
text = clip.tokenize(classes).to(device)

examples = dataset.fetch_per_type_examples()
for _ in tqdm(range(5)):
    optimizer.zero_grad()

    labels = []
    images = []
    for ex in examples:
        image, label = ex
        labels.append(classes.index(label))
        image_tensor = preprocess(image).unsqueeze(0).to(device)
        images.append(image_tensor)

    logits, _ = model(torch.cat(images), text)
    probs = logits.softmax(dim=-1)
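
The snippet stops right after computing the logits; a minimal completion of the training step, assuming a standard cross-entropy objective over the class indices collected above:

    # Continues inside the `for _ in tqdm(range(5)):` loop.
    target = torch.tensor(labels, device=device)
    loss = torch.nn.functional.cross_entropy(logits, target)
    loss.backward()
    optimizer.step()
    if save_preds:
        preds = probs.argmax(dim=-1).tolist()
        print([classes[p] for p in preds])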
예제 #27
0
                with env.begin(db=fn_db) as txn:
                    features = np.frombuffer(txn.get(key),
                                             dtype=np.float32).reshape(
                                                 (1, 512))
                print(f"Similar to {key.decode()}:")
            except Exception:
                print("Not found.")
                continue
        elif in_text == '':
            offset = last_j
            if texts is None:
                continue
        else:
            offset = 0
            last_j = 0
            texts = clip.tokenize([in_text]).to(device)
            features = normalize(
                model.encode_text(texts).detach().cpu().numpy().astype(
                    'float32'))

        search_start = time.perf_counter()
        D, I = index.search(features, k + offset + 1)
        search_time = time.perf_counter() - search_start
        print(f"Search time: {search_time:.4f}s")
        for j, i in enumerate(I[0]):
            if j <= offset:
                continue
            with env.begin(db=idx_db) as txn:
                tfn = txn.get(f"{i}".encode()).decode()
                print(f"{D[0][j]:.4f} {i} {tfn}")
                try:
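
The fragment above queries a FAISS index over L2-normalized CLIP embeddings that is built elsewhere; a hedged sketch of how such an index could be constructed (the 512-dimensional feature size matches ViT-B/32, and the variable names are assumptions):

import faiss
import numpy as np

def build_index(features):
    # features: [N, 512] float32, already L2-normalized, so inner product
    # equals cosine similarity.
    index = faiss.IndexFlatIP(features.shape[1])
    index.add(features)
    return index

# index = build_index(np.stack(all_image_features).astype('float32'))
# D, I = index.search(query_features, k)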
예제 #28
0
def main():
  # Reproducibility.
  torch.manual_seed(args.seed)
  random.seed(args.seed)
  np.random.seed(args.seed)

  # Obtain the utilized device.
  if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
    has_cuda = True
  else:
    device = torch.device("cpu")
    has_cuda = False

  # Setup logging.
  query = queries[args.query_idx]
  exp_name = f"{args.exp_name_prefix}_{args.query_idx:03d}_{query}"
  wandb.init(
      entity=args.wandb_entity,
      project=args.wandb_project,
      name=exp_name,
      config=args)

  # Initialize CLIP
  if args.clip_lam:
    model, preprocess, clip_size = load_clip(args.loss_model, device)
    model.eval()

  if args.retrieve_model == args.loss_model and args.clip_lam:
    test_model, test_preprocess, test_clip_size = model, preprocess, clip_size
  else:
    test_model, test_preprocess, test_clip_size = load_clip(
        args.retrieve_model, device)
    test_model.eval()

  # Initialize the volumetric model.
  volume_model = nerf.DreamFieldsMLP(
      activation="SiLU",
      features_early=[96],  # Dense layers before residual blocks.
      features_residual=[(128, 96)] * 3,  # Resid block feature dimensions.
      features_late=[96, 4],  # Features dimensions at end.
      fourfeat=args.fourfeat,
      max_deg=args.posenc_deg,
      ipe=args.ipe,
  )
  volume_model = nn.DataParallel(volume_model)
  volume_model = volume_model.to(device)
  scene_origin = scene.EMA(np.zeros(3, dtype=np.float64), decay=0.999)
  render_kwargs = dict(
      sigma_noise_std=args.sigma_noise_std,
      near=4. - math.sqrt(3) * args.volume_extent_world / 2,
      far=4. + math.sqrt(3) * args.volume_extent_world / 2,
      mask_rad=args.volume_extent_world / 2,
      n_pts_per_ray=args.n_pts_per_ray,
      device=device,
  )

  # Instantiate the Adam optimizer.
  optimizer = torch.optim.Adam(
      volume_model.parameters(), lr=args.lr_init, eps=args.adam_eps)
  scaler = torch.cuda.amp.GradScaler()

  # Embed the target caption with CLIP.
  if args.clip_lam:
    query_tok = clip.tokenize(query).to(device)
    z_clip = model.encode_text(query_tok).detach()
    z_clip = F.normalize(z_clip, dim=-1)

    clip_aug_fn = torchvision.transforms.RandomResizedCrop(
        clip_size, scale=args.crop_scale_range, ratio=(1.0, 1.0))

  if args.diffusion_lam:
    # Initialize GLIDE. Create base model.
    base_glide_model, diffusion, base_glide_options = load_diffusion(
        "base", device, has_cuda=has_cuda)
    base_glide_model.eval()

    # Embed the target caption with GLIDE.
    denoise_batch_size = (
        args.n_aug * args.n_views if args.denoise_augmented else args.n_views)
    tokens = base_glide_model.tokenizer.encode(query)
    tokens, mask = base_glide_model.tokenizer.padded_tokens_and_mask(
        tokens, base_glide_options["text_ctx"])

    # Create the classifier-free guidance tokens (empty).
    uncond_tokens, uncond_mask = base_glide_model.tokenizer.padded_tokens_and_mask(
        [], base_glide_options["text_ctx"])

    # Pack the tokens together into model kwargs.
    base_model_kwargs = dict(
        tokens=torch.tensor(
            [tokens] * denoise_batch_size +
            [uncond_tokens] * denoise_batch_size,
            device=device),
        mask=torch.tensor(
            [mask] * denoise_batch_size + [uncond_mask] * denoise_batch_size,
            dtype=torch.bool,
            device=device),
    )

    parallel_glide = nn.DataParallel(base_glide_model)

    # Create a classifier-free guidance sampling function.
    def base_model_fn(x_t, ts, **kwargs):
      half = x_t[:len(x_t) // 2]
      combined = torch.cat([half, half], dim=0)
      model_out = parallel_glide(combined, ts, **kwargs)
      eps, rest = model_out[:, :3], model_out[:, 3:]
      cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
      half_eps = uncond_eps + args.guidance_scale * (cond_eps - uncond_eps)
      eps = torch.cat([half_eps, half_eps], dim=0)
      return torch.cat([eps, rest], dim=1)

    def preprocess_glide(x, order="NHWC"):
      if order == "NHWC":
        # x is [NHWC]. Reshape to NCHW.
        x = x.movedim(-1, 1)
      x = x * 2 - 1  # Scale from [0, 1] to [-1, 1].
      if x.shape[-2:] != (64, 64):
        x = F.interpolate(x, (64, 64), mode="bilinear")
      return x

    def unprocess_glide(x):
      return (x + 1) / 2  # Scale from [-1, 1] to [0, 1].

    denoised_fn = lambda x_start: x_start
    denoise_aug_fn = torchvision.transforms.RandomResizedCrop(
        64, scale=args.crop_scale_range, ratio=(1.0, 1.0))

    glide_context_manager = (
        torch.no_grad if args.denoise_stop_grad else torch.enable_grad)

    # Initialize each chain.
    diffusion_x = torch.randn((args.n_views, 3, 64, 64),
                              device=device,
                              requires_grad=False)
    diffusion_t = torch.full(
        size=(args.n_views,),
        fill_value=args.t_respace - 1,
        requires_grad=False,
        dtype=torch.long,
        device=device)

    # Training uses n_iter iterations: 1 to n_iter (inclusive).
    # Diffusion uses t_respace timesteps: t_respace-1 to 0 (inclusive).
    # For now, check they are equal.
    # TODO(jainajay): implement sampling with non-unit timesteps.
    assert args.t_respace * args.denoise_every == args.n_iter

  # Get a batch of viewing angles and pre-generate rays.
  azimuths = np.arange(args.n_views) * 360. / args.n_views
  rads = np.full(args.n_views, 4.)
  focal_mults = np.full(args.n_views, 1.2)
  elevations = [
      scene.uniform_in_interval(args.elevation_range)
      for _ in range(args.n_views)
  ]
  cam2worlds = [
      scene.pose_spherical(azim, phi=elev, radius=rad)
      for azim, elev, rad in zip(azimuths, elevations, rads)
  ]
  height, width, focal = scene.scale_intrinsics(args.render_size)
  # Generate rays: 3-tuple of [n_views, H, W, n_pts_per_ray, 3 or 1].
  rays_all_views = scene.camera_rays_batched(cam2worlds, height, width,
                                             focal_mults * focal)

  pbar = tqdm.trange(1, args.n_iter + 1)
  for iteration in pbar:
    metrics = {}
    visualize_images = iteration % 25 == 0 or iteration == 1

    # Set learning rate
    lr = schedule.learning_rate_decay(
        iteration,
        args.lr_init,
        args.lr_final,
        args.n_iter,
        lr_delay_steps=min(args.n_iter // 8, 2500),
        lr_delay_mult=args.lr_delay_mult)
    for g in optimizer.param_groups:
      g["lr"] = float(lr)

    # Zero the optimizer gradient.
    optimizer.zero_grad()

    # Render the volumetric model from random perspectives.
    batch_idx = np.random.choice(
        args.n_views, size=args.batch_size, replace=False)
    rays_batched = [r[batch_idx] for r in rays_all_views]

    # Runs the forward pass with automatic precision casting.
    with torch.cuda.amp.autocast():
      (images, depths, disparities, silhouettes), _ = nerf.render_rays_mip(
          rays_batched,
          volume_model,
          origin=scene_origin.value,
          **render_kwargs)
      assert images.ndim == 4
      assert images.shape[0] == args.batch_size
      assert images.shape[-1] == 3

      # Transmittance loss. Anneal target opacity (1 - transmittance).
      target_opacity = schedule.anneal_logarithmically(
          iteration, args.target_transmittance_anneal_iters,
          1 - args.target_transmittance0, 1 - args.target_transmittance1)

      # The area of an object on the image plane grows with the focal length
      # and shrinks with increasing camera radius. Scale target opacity
      # proportionally with the squared focal multiplier and inversely
      # proportionally with the squared camera radius.
      target_opacities = np.minimum(
          np.ones(args.batch_size), focal_mults[batch_idx]**2 /
          (rads[batch_idx] / 4.)**2 * target_opacity)
      taus = torch.tensor(1 - target_opacities, device=device)
      avg_transmittance = 1 - silhouettes.mean(
          dim=tuple(range(1, silhouettes.ndim)))
      # NOTE(jainajay): Using a modified, two-sided transmittance loss that
      # differs from Dream Fields. It can encourage reducing transmittance if
      # the scene becomes too sparse. The original loss would penalize
      # -torch.mean(torch.min(avg_transmittance, taus)).
      transmittance_loss = torch.mean(torch.abs(avg_transmittance - taus))

      # Data augmentation.
      if (args.diffusion_lam > 0 and
          args.denoise_augmented) or args.clip_lam > 0:
        # NOTE(jainajay): this background is at the render resolution,
        #       not the resize, unlike Dream Fields.
        # Generate random backgrounds.
        bgs = augment.sample_backgrounds(
            num=args.n_aug * args.batch_size,
            res=args.render_size,
            checkerboard_nsq=args.nsq,
            min_blur_std=args.bg_blur_std_range[0],
            max_blur_std=args.bg_blur_std_range[1],
            device=device)

        # Composite renders with backgrounds.
        bgs = bgs.view(args.n_aug, args.batch_size, *bgs.shape[1:])  # ANCHW.
        bgs = bgs.movedim(2, -1)  # Convert ANCHW to ANHWC.
        composite_images = (
            silhouettes[None] * images[None] + (1 - silhouettes[None]) * bgs)
        composite_images = composite_images.reshape(  # to A*N,H,W,C.
            args.n_aug * args.batch_size, args.render_size, args.render_size, 3)
        composite_images = composite_images.movedim(3, 1)  # NHWC to NCHW.

      # Compute GLIDE loss.
      # Sample from the base model.
      if args.diffusion_lam:
        # Preprocess rendering (scale to [-1, 1]).
        if args.denoise_augmented:
          denoise_aug_images = denoise_aug_fn(composite_images)
          inp = preprocess_glide(denoise_aug_images, order="NCHW")
        else:
          inp = silhouettes * images + 1 - silhouettes  # white bg
          inp = preprocess_glide(inp, order="NHWC")

        if (iteration - 1) % args.denoise_every == 0:
          base_glide_model.del_cache()

          # Sampling step for every view in the cache.
          with glide_context_manager():
            assert diffusion_t.dtype == torch.long
            assert torch.all(diffusion_t == diffusion_t[0])
            metrics["diffusion/t"] = diffusion_t[0].item()

            xt = diffusion_x  # || x_hat(x_t) - render ||^2

            # Enable for loss: || x_hat(diffuse(render)) - x_hat(x_t) ||^2
            # x = diffusion.q_sample(
            #     inp, torch.tensor([diffusion_t] * denoise_batch_size,
            #     device=device))

            # Sample x_s from x_t using DDIM.
            # Based on glide-text2im/glide_text2im/gaussian_diffusion.py#L453
            assert args.batch_size == args.n_views  # Updating all chains.
            out = diffusion.p_mean_variance(
                base_model_fn,
                torch.cat([xt, xt], dim=0),
                torch.cat([diffusion_t, diffusion_t], dim=0),
                clip_denoised=True,
                denoised_fn=denoised_fn,  # TODO(jainajay): look into this,
                model_kwargs=base_model_kwargs,
            )
            assert out["pred_xstart"].shape[0] == 2 * args.batch_size
            pred_xstart = out["pred_xstart"][:args.batch_size]

            if iteration < args.independent_sampling_steps * args.denoise_every:
              # Ours: eps = pred_eps(x_t, t, tilde_x).
              # Ours: x_{t-1} = a * tilde_x + b * eps + sigma * noise.
              x0_for_sampling = pred_xstart
            else:
              # GLIDE: eps = pred_eps(x_t, t, x_hat(x_t)).
              # GLIDE: x_{t-1} = a * x_hat(x_t) + b * eps + sigma * noise.
              x0_for_sampling = inp.detach()

            # pylint: disable=protected-access
            eps = diffusion._predict_eps_from_xstart(diffusion_x, diffusion_t,
                                                     x0_for_sampling)
            # pylint: enable=protected-access

            assert eps.shape[0] == args.batch_size

            alpha_bar = _extract_into_tensor(diffusion.alphas_cumprod,
                                             diffusion_t, xt.shape)
            metrics["diffusion/alpha_bar"] = alpha_bar.mean().item()
            alpha_bar_prev = _extract_into_tensor(diffusion.alphas_cumprod_prev,
                                                  diffusion_t, xt.shape)
            metrics["diffusion/alpha_bar_prev"] = alpha_bar_prev.mean().item()
            sigma = (
                args.ddim_eta * torch.sqrt(
                    (1 - alpha_bar_prev) / (1 - alpha_bar)) *
                torch.sqrt(1 - alpha_bar / alpha_bar_prev))
            metrics["diffusion/sigma"] = sigma.mean().item()
            # Equation 12.
            mean_pred = (
                x0_for_sampling * torch.sqrt(alpha_bar_prev) +
                torch.sqrt(1 - alpha_bar_prev - sigma**2) * eps)
            nonzero_mask = (
                (diffusion_t != 0).float().view(-1,
                                                *([1] * (len(xt.shape) - 1)))
            )  # No noise when t == 0.
            noise = torch.randn_like(xt)
            sample = mean_pred + nonzero_mask * sigma * noise

            # Update multiview sampling chains.
            diffusion_x_prev = diffusion_x
            diffusion_x = sample
            diffusion_t = diffusion_t - 1

            # Don't backprop through the denoiser (forces stop_grad True).
            assert args.denoise_stop_grad
            pred_xstart = pred_xstart.detach()

          base_glide_model.del_cache()

        # Loss: ||x_hat(x_t) - render||^2.
        # Slicing the predictions only optimizes a few views.
        diffusion_loss = F.mse_loss(pred_xstart[:args.n_optimize],
                                    inp[:args.n_optimize])

        # TODO(jainajay): Try other losses. Some possibilities:
        #   ||x_hat(render) - render||^2 (change L480)
        #   ||x_hat(x_t) - x_hat(diffuse(render))||^2
        #         (change denoising code to denoise render and x_t)
        #   ||eps - eps_hat(diffuse(render), eps)||^2
        #   ||eps_hat(x_t) - eps_hat(diffuse(render), eps)||^2
        #       only makes sense if that's the eps in x_t
        metrics["loss/diffusion_mse"] = diffusion_loss
      else:
        diffusion_loss = torch.tensor([0.], device=device)

      # Compute the CLIP loss.
      if args.clip_lam:
        clip_aug_images = clip_aug_fn(composite_images)
        x = preprocess(clip_aug_images)  # Resize and normalize.
        z_est = model.encode_image(x)
        z_est = F.normalize(z_est, dim=-1)
        clip_loss = -torch.sum(z_est * z_clip, dim=-1).mean()
      else:
        clip_loss = torch.tensor([0.], device=device)

      # Compute total loss and take an optimization step.
      loss = (
          args.clip_lam * clip_loss +
          args.transmittance_lam * transmittance_loss +
          args.diffusion_lam * diffusion_loss)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    if args.track_scene_origin:
      raise NotImplementedError

    # Logging.
    with torch.inference_mode():
      volume_model.eval()

      metrics["train/depths/min"] = depths.min()
      metrics["train/depths/max"] = depths.max()
      metrics["train/disparities/min"] = disparities.min()
      metrics["train/disparities/max"] = disparities.max()

      metrics.update({
          "schedule/lr": lr,
          "loss/total_loss": loss.item(),
          "loss/clip": clip_loss.item(),
          "loss/transmittance": transmittance_loss.item(),
          "train/avg_transmittance": avg_transmittance.mean().item()
      })

      # Print the current values of the losses.
      if iteration % 10 == 0:
        pbar.set_description(
            f"Iteration {iteration:05d}:" +
            f" clip_loss = {float(clip_loss.item()):1.2f}" +
            f" diffusion_loss = {float(diffusion_loss.item()):1.5f}" +
            f" avg transmittance = {float(avg_transmittance.mean().item()):1.2f}"
        )

      # Visualize the renders.
      if visualize_images:
        metrics["render/rendered"] = wandb_grid(images)
        metrics["render/silhouettes"] = wandb_grid(silhouettes)
        metrics["render/rendered_depth"] = wandb_grid(depths)

        if args.clip_lam > 0:
          metrics["render/augmented"] = wandb_grid(clip_aug_images)

        if args.diffusion_lam:
          # Show diffusion_x_prev, diffusion_x (sample), out['pred_xstart'].
          for name, val in zip(["x_t", "x_tm1", "pred_xstart"],
                               [diffusion_x_prev, diffusion_x, pred_xstart]):
            print("diffusion", name, val.shape, val.min(), val.max())
            val = unprocess_glide(val)  # [n_views, C, 64, 64]
            metrics[f"diffusion/{name}"] = wandb_grid(val)

      # Validate from a held-out view.
      if iteration % 250 == 0 or iteration == 1:
        validation_view = render_validation_view(
            volume_model,
            scene_origin,
            test_clip_size,
            args.max_validation_size,
            **render_kwargs)
        assert validation_view.ndim == 3
        assert validation_view.shape[-1] == 3
        metrics["val/render"] = wandb.Image(clamp_and_detach(validation_view))

        rank, cosine_sim = compute_query_rank(
            test_model,
            test_preprocess,
            render=validation_view.movedim(-1, 0).unsqueeze(0),
            query=query,
            queries_r=queries,
            device=device)

        metrics["val/rank"] = rank
        metrics["val/acc"] = int(rank == 0)
        metrics["val/cosine_sim"] = cosine_sim

      if iteration % 250 == 0 or iteration == 1:
        # Visualize the optimized volume by rendering from multiple viewpoints
        # that rotate around the volume's y-axis.
        video_frames = render_rotating_volume(
            volume_model,
            scene_origin=scene_origin,
            video_size=args.video_size,
            n_frames=args.video_n_frames,
            **render_kwargs)

        for name, frames in zip(["rgb", "depth", "disparity", "silhouette"],
                                video_frames):
          # frames is in THWC order.
          filename = f"/tmp/{iteration:05d}_{name}.mp4"
          if frames.shape[-1] == 1:
            media.write_video(filename, frames[Ellipsis, 0], fps=30)
          else:
            media.write_video(filename, frames, fps=30)
          print("wrote", filename,
                f"range: [{frames.min():.4f}, {frames.max():.4f}]")

          metrics[f"render/video/{name}"] = wandb.Video(
              filename, fps=30, format="mp4")

      wandb.log(metrics, iteration)

      volume_model.train()
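
The logging code above calls a wandb_grid helper that is not part of the snippet; a plausible minimal version, assuming renders arrive as float tensors in [N, H, W, C], [N, H, W], or NCHW layout:

import torchvision
import wandb

def wandb_grid(images):
  # Assumed helper: tile a batch of renders into a single wandb.Image.
  x = images.detach().float().cpu()
  if x.ndim == 3:              # [N, H, W] -> add a channel axis
    x = x.unsqueeze(-1)
  if x.shape[-1] in (1, 3):    # NHWC -> NCHW for make_grid
    x = x.movedim(-1, 1)
  grid = torchvision.utils.make_grid(x, normalize=True)
  return wandb.Image(grid.movedim(0, -1).numpy())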
예제 #29
0
def main():
    a = get_args()

    # Load CLIP models
    model_clip, _ = clip.load(a.model)
    if a.verbose is True: print(' using model', a.model)
    xmem = {'RN50':0.5, 'RN50x4':0.16, 'RN101':0.33}
    if 'RN' in a.model:
        a.samples = int(a.samples * xmem[a.model])
    workdir = os.path.join(a.out_dir, basename(a.in_txt))
    workdir += '-%s' % a.model if 'RN' in a.model.upper() else ''
    os.makedirs(workdir, exist_ok=True)

    if a.diverse != 0:
        a.samples = int(a.samples * 0.5)
            
    if a.transform is True:
        trform_f = transforms.transforms_custom  
        a.samples = int(a.samples * 0.95)
    else:
        trform_f = transforms.normalize()

    if a.in_txt0 is not None:
        if a.verbose is True: print(' subtract text:', basename(a.in_txt0))
        if a.translate:
            translator = Translator()
            a.in_txt0 = translator.translate(a.in_txt0, dest='en').text
            if a.verbose is True: print(' translated to:', a.in_txt0) 
        if a.multilang is True:
            model_lang = SentenceTransformer('clip-ViT-B-32-multilingual-v1').cuda()
            txt_enc0 = model_lang.encode([a.in_txt0], convert_to_tensor=True, show_progress_bar=False).detach().clone()
            del model_lang
        else:
            txt_enc0 = model_clip.encode_text(clip.tokenize(a.in_txt0).cuda()).detach().clone()

    # make init
    global params_start, params_ema
    params_shape = [1, 3, a.size[0], a.size[1]//2+1, 2]
    params_start = torch.randn(*params_shape).cuda() # random init
    params_ema = 0.
    if a.resume is not None and os.path.isfile(a.resume):
        if a.verbose is True: print(' resuming from', a.resume)
        params_start = load_params(a.resume).cuda()
        if a.keep > 0:
            params_ema = params_start[0].detach().clone()
    else:
        a.resume = 'init.pt'

    torch.save(params_start, 'init.pt') # final init
    shutil.copy(a.resume, os.path.join(workdir, '000-%s.pt' % basename(a.resume)))
    
    prev_enc = 0
    def process(txt, num):

        sd = 0.01
        if a.keep > 0: sd = a.keep + (1-a.keep) * sd
        params, image_f = fft_image([1, 3, *a.size], resume='init.pt', sd=sd, decay_power=a.decay)
        image_f = to_valid_rgb(image_f, colors = a.colors)

        if a.prog is True:
            lr1 = a.lrate * 2
            lr0 = a.lrate * 0.1
        else:
            lr0 = a.lrate
        optimizer = torch.optim.Adam(params, lr0)
    
        if a.verbose is True: print(' ref text: ', txt)
        if a.translate:
            translator = Translator()
            txt = translator.translate(txt, dest='en').text
            if a.verbose is True: print(' translated to:', txt)
        if a.multilang is True:
            model_lang = SentenceTransformer('clip-ViT-B-32-multilingual-v1').cuda()
            txt_enc = model_lang.encode([txt], convert_to_tensor=True, show_progress_bar=False).detach().clone()
            del model_lang
        else:
            txt_enc = model_clip.encode_text(clip.tokenize(txt).cuda()).detach().clone()
        if a.notext > 0:
            txt_plot = torch.from_numpy(plot_text(txt, a.modsize)/255.).unsqueeze(0).permute(0,3,1,2).cuda()
            txt_plot_enc = model_clip.encode_image(txt_plot).detach().clone()
        else: txt_plot_enc = None

        out_name = '%03d-%s' % (num+1, txt_clean(txt))
        out_name += '-%s' % a.model if 'RN' in a.model.upper() else ''
        tempdir = os.path.join(workdir, out_name)
        os.makedirs(tempdir, exist_ok=True)
        
        pbar = ProgressBar(a.steps // a.fstep)
        for i in range(a.steps):
            loss = 0

            noise = a.noise * torch.randn(1, 1, *params[0].shape[2:4], 1).cuda() if a.noise > 0 else None
            img_out = image_f(noise)
            
            if a.sharp != 0:
                lx = torch.mean(torch.abs(img_out[0,:,:,1:] - img_out[0,:,:,:-1]))
                ly = torch.mean(torch.abs(img_out[0,:,1:,:] - img_out[0,:,:-1,:]))
                loss -= a.sharp * (ly+lx)

            imgs_sliced = slice_imgs([img_out], a.samples, a.modsize, trform_f, a.align, micro=1.)
            out_enc = model_clip.encode_image(imgs_sliced[-1])
            loss -= torch.cosine_similarity(txt_enc, out_enc, dim=-1).mean()
            if a.notext > 0:
                loss += a.notext * torch.cosine_similarity(txt_plot_enc, out_enc, dim=-1).mean()
            if a.diverse != 0:
                imgs_sliced = slice_imgs([image_f(noise)], a.samples, a.modsize, trform_f, a.align, micro=1.)
                out_enc2 = model_clip.encode_image(imgs_sliced[-1])
                loss += a.diverse * torch.cosine_similarity(out_enc, out_enc2, dim=-1).mean()
                del out_enc2; torch.cuda.empty_cache()
            if a.expand > 0:
                global prev_enc
                if i > 0:
                    loss += a.expand * torch.cosine_similarity(out_enc, prev_enc, dim=-1).mean()
                prev_enc = out_enc.detach().clone()
            if a.in_txt0 is not None: # subtract text
                loss += torch.cosine_similarity(txt_enc0, out_enc, dim=-1).mean()
            del img_out, imgs_sliced, out_enc; torch.cuda.empty_cache()

            if a.prog is True:
                lr_cur = lr0 + (i / a.steps) * (lr1 - lr0)
                for g in optimizer.param_groups: 
                    g['lr'] = lr_cur
        
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % a.fstep == 0:
                with torch.no_grad():
                    img = image_f(contrast=a.contrast).cpu().numpy()[0]
                if a.sharp != 0:
                    img = img **1.3 # empirical tone mapping
                checkout(img, os.path.join(tempdir, '%04d.jpg' % (i // a.fstep)), verbose=a.verbose)
                pbar.upd()
                del img

        if a.keep > 0:
            global params_start, params_ema
            params_ema = ema(params_ema, params[0].detach().clone(), num+1)
            torch.save((1-a.keep) * params_start + a.keep * params_ema, 'init.pt')
        
        torch.save(params[0], '%s.pt' % os.path.join(workdir, out_name))
        shutil.copy(img_list(tempdir)[-1], os.path.join(workdir, '%s-%d.jpg' % (out_name, a.steps)))
        os.system('ffmpeg -v warning -y -i %s\%%04d.jpg "%s.mp4"' % (tempdir, os.path.join(workdir, out_name)))

    with open(a.in_txt, 'r', encoding="utf-8") as f:
        texts = f.readlines()
        texts = [tt.strip() for tt in texts if len(tt.strip()) > 0 and tt[0] != '#']
    if a.verbose is True: 
        print(' total lines:', len(texts))
        print(' samples:', a.samples)

    for i, txt in enumerate(texts):
        process(txt, i)

    vsteps = int(a.length * 25 / len(texts)) # 25 fps
    tempdir = os.path.join(workdir, '_final')
    os.makedirs(tempdir, exist_ok=True)
    
    def read_pt(file):
        return torch.load(file).cuda()

    if a.verbose is True: print(' rendering complete piece')
    ptfiles = file_list(workdir, 'pt')
    pbar = ProgressBar(vsteps * len(ptfiles))
    for px in range(len(ptfiles)):
        params1 = read_pt(ptfiles[px])
        params2 = read_pt(ptfiles[(px+1) % len(ptfiles)])

        params, image_f = fft_image([1, 3, *a.size], resume=params1, sd=1., decay_power=a.decay)
        image_f = to_valid_rgb(image_f, colors = a.colors)

        for i in range(vsteps):
            with torch.no_grad():
                img = image_f((params2 - params1) * math.sin(1.5708 * i/vsteps)**2)[0].permute(1,2,0)
                img = torch.clip(img*255, 0, 255).cpu().numpy().astype(np.uint8)
            imsave(os.path.join(tempdir, '%05d.jpg' % (px * vsteps + i)), img)
            if a.verbose is True: cvshow(img)
            pbar.upd()

    os.system('ffmpeg -v warning -y -i %s\%%05d.jpg "%s.mp4"' % (tempdir, os.path.join(a.out_dir, basename(a.in_txt))))
    if a.keep > 0: os.remove('init.pt')
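
Example #29 also uses an ema helper, called as ema(params_ema, new_params, num+1) with params_ema initialized to 0.; a sketch of a running average consistent with that call pattern:

def ema(prev, new, n):
    # Running mean over the first n items: the n-th tensor gets weight 1/n,
    # so ema(0., first, 1) returns the first tensor unchanged.
    return new if n == 1 else prev + (new - prev) / n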
예제 #30
0
def main(args):

    text_inputs = torch.cat([clip.tokenize(args.description)]).cuda()
    os.makedirs(args.results_dir, exist_ok=True)

    F = PerceptualModel(min_val=-1.0, max_val=1.0)

    g_ema = Generator(1024, 512, 8)
    g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
    g_ema.eval()
    g_ema = g_ema.cuda()
    z_mean = g_ema.mean_latent(4096)
    # z_load = np.load(args.latent_path)
    # z_init = torch.from_numpy(z_load).cuda()
    # print(np.shape(latent_load))
    F_OOM = args.f_oom

    if args.mode == "man":
        z_init = torch.load(args.latent_path).cuda()
    else:
        z_init_not_trunc = torch.randn(1, 512).cuda()
        with torch.no_grad():
            _, z_init = g_ema([z_init_not_trunc],
                              truncation_latent=z_mean,
                              return_latents=True,
                              truncation=0.7)

    x, _ = g_ema([z_init], input_is_latent=True, randomize_noise=False)

    # z = z_init.detach().clone()
    z = z_mean.detach().clone().repeat(1, 18, 1)

    z.requires_grad = True

    clip_loss = CLIPLoss()

    optimizer = optim.Adam([z], lr=args.lr)

    pbar = tqdm(range(args.step))

    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr)
        optimizer.param_groups[0]["lr"] = lr

        x_rec, _ = g_ema([z], input_is_latent=True, randomize_noise=False)
        if not F_OOM:
            loss = 0.0
            # Reconstruction loss.
            loss_pix = torch.mean((x - x_rec)**2)
            loss = loss + loss_pix * args.loss_pix_weight
            log_message = f'loss_pix: {_get_tensor_value(loss_pix):.3f}'

            # Perceptual loss.
            if args.loss_feat_weight:
                x_feat = F.net(x)
                x_rec_feat = F.net(x_rec)
                loss_feat = torch.mean((x_feat - x_rec_feat)**2)
                loss = loss + loss_feat * args.loss_feat_weight
                log_message += f', loss_feat: {_get_tensor_value(loss_feat):.3f}'

            # Regularization loss.
            if args.loss_reg_weight:
                loss_reg = torch.mean((z_init - z)**2)
                # loss_reg = ((z_init - z) ** 2).sum()
                loss = loss + loss_reg * args.loss_reg_weight
                log_message += f', loss_reg: {_get_tensor_value(loss_reg):.3f}'

            # CLIP loss.
            if args.loss_clip_weight:
                loss_clip = clip_loss(x_rec, text_inputs)
                loss = loss + loss_clip[0][0] * args.loss_clip_weight
                log_message += f', loss_clip: {_get_tensor_value(loss_clip[0][0]):.3f}'
        else:
            loss_reg = ((z_init - z)**2).sum()
            loss_clip = clip_loss(x_rec, text_inputs)
            # loss_clip_weight was set to 200 in this configuration.
            loss = loss_reg + loss_clip[0][0] * args.loss_clip_weight

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_description((f"loss: {loss.item():.4f};"))

    final_result = torch.cat([x, x_rec])
    return final_result
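
The logging in example #30 calls _get_tensor_value, which is not included; a minimal sketch that moves a (possibly GPU) tensor to the CPU and unwraps scalars for the f-string log messages:

def _get_tensor_value(tensor):
    # Assumed helper: detach, move to CPU, and return a plain number or array.
    value = tensor.detach().cpu()
    return value.item() if value.numel() == 1 else value.numpy()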