Example #1
    def log_result(self, variables, step_iter):
        # Record the loss for this step: use the benchmark metrics when a
        # benchmark ('bm') is attached, otherwise fall back to the raw loss.
        if hasattr(self, 'bm'):
            res = self.benchmark(variables, self.out)
        else:
            res = {'loss': np.array(self.loss)}
        self.losses.append([step_iter, res])
        # Keep an image snapshot of the current outputs for visualization.
        self.outs.append(to_image(make_grid(self.out), cv2_format=False))
        return
Example #2
def save_result(save_dir, fn, collages, target, weight, out, vars, losses,
                t_outs=None, t_out=None, t_target=None, transform=None):
    jpg_quality = [int(cv2.IMWRITE_JPEG_QUALITY), 100]

    # Select the seed with the lowest VGG loss at the final step.
    idx = np.argmin(losses[-1][1]['vgg'])
    make_video(osp.join(save_dir, '{}.mp4'.format(fn)), collages, duration=5)
    cv2.imwrite(osp.join(save_dir, '{}.target.jpg'.format(fn)),
                to_image(target)[0], jpg_quality)
    cv2.imwrite(osp.join(save_dir, '{}.weight.jpg'.format(fn)),
                to_image(weight)[0], jpg_quality)

    if t_out is not None:
        cv2.imwrite(osp.join(save_dir, '{}.transform.final.jpg'.format(fn)),
                    to_image(out)[idx], jpg_quality)
        cv2.imwrite(osp.join(save_dir, '{}.final.jpg'.format(fn)),
                    to_image(t_out)[idx], jpg_quality)
    else:
        cv2.imwrite(osp.join(save_dir, '{}.final.jpg'.format(fn)),
                    to_image(out)[idx], jpg_quality)

    if t_target is not None:
        cv2.imwrite(osp.join(save_dir, '{}.transform.target.jpg'.format(fn)),
                    to_image(t_target)[0], jpg_quality)

    # Save the raw (uncompressed) loss history for later analysis
    np.save(osp.join(save_dir, '{}.loss.npy'.format(fn)), np.array(losses))

    if t_outs is not None:
        make_video(osp.join(save_dir, '{}.transform.mp4'.format(fn)),
                   t_outs[0], duration=5)
        make_video(osp.join(save_dir, '{}.transform.out.mp4'.format(fn)),
                   t_outs[1], duration=5)

    exp_vars = {'vars': vars, 'transform': transform}
    np.save(osp.join(save_dir, '{}.vars.npy'.format(fn)), exp_vars)
    return
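
For reference, save_result expects losses to be a list of [step, metrics]
pairs whose final entry carries a per-seed 'vgg' array (presumably the shape
produced by the benchmark path in Example #1). A minimal sketch of how the
best seed index is selected, with made-up numbers:

import numpy as np

# Hypothetical loss log: one [step, metrics] entry per logged step.
losses = [[0,  {'vgg': np.array([0.90, 0.70, 0.80])}],
          [50, {'vgg': np.array([0.50, 0.30, 0.40])}]]

idx = np.argmin(losses[-1][1]['vgg'])  # best seed at the final step -> 1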
Example #3
def search_transform(model,
                     transform_fn,
                     var_manager,
                     loss_fn,
                     meta_steps=30,
                     grad_steps=30,
                     z_sigma=0.5,
                     t_sigma=0.05,
                     log=False,
                     pbar=None):
    """
    Searches for a transformation parameter to apply to the target image so
    that the image becomes easier to invert. The transformation is optimized
    in a BasinCMA-like fashion: the outer loop optimizes the transformation
    with CMA, and the inner loop optimizes the latent variables with gradient
    descent (a standalone sketch of this pattern follows this example).

    TODO: Clean up the code and integrate with optimizers.py

    Args:
        model:
            generative model
        transform_fn:
            the transformation search function
        var_manager:
            variable manager
        loss_fn:
            loss function to optimize with
        meta_steps:
            Number of CMA updates for transformation parameter
        grad_steps:
            Number of ADAM updates for z, c
        z_sigma:
            Sigma for latent variable resampling
        t_sigma:
            Sigma for transformation parameter
        log:
            Returns intermediate transformation result if True
        pbar:
            Progress bar such as tqdm or st.progress()
    Returns:
        Transformation parameters
    """

    # -- setup CMA -- #
    var_manager.optimize_t = True
    t_outs, outs, step_iter = [], [], 0
    total_steps = meta_steps * grad_steps
    t_cma_opt = CMA(mu=var_manager.t.cpu().numpy()[0], sigma=t_sigma)

    if t_cma_opt.batch_size() > var_manager.num_seeds:
        import nevergrad as ng
        print('Number of seeds is less than that required by PyCMA ' +
              'transformation search. Using Nevergrad CMA instead.')
        batch_size = var_manager.num_seeds
        opt_fn = ng.optimizers.registry['CMA']
        p = ng.p.Array(init=var_manager.t.cpu().numpy()[0])
        p = p.set_mutation(sigma=t_sigma)
        t_cma_opt = opt_fn(parametrization=p, budget=meta_steps)
        variables = var_manager.init(var_manager.num_seeds)
        using_nevergrad = True
    else:
        batch_size = t_cma_opt.batch_size()
        variables = var_manager.init(batch_size)
        using_nevergrad = False

    # Note: we will use the binarized weight for the CMA loss.
    mask = binarize(variables.weight.data[0]).unsqueeze(0)
    target = variables.target.data[0].unsqueeze(0)

    for i in range(meta_steps):

        # -- initialize variable -- #
        if using_nevergrad:
            _t = [t_cma_opt.ask() for _ in range(var_manager.num_seeds)]
            t = np.concatenate([np.array(x.args) for x in _t])
        else:
            t = t_cma_opt.ask()

        t_params = torch.Tensor(t)
        z = propagate_z(variables, z_sigma)
        variables = var_manager.init(num_seeds=batch_size,
                                     z=z.detach().float().cuda(),
                                     t=t_params.detach().float().cuda())

        for x in variables.t.data:
            x.requires_grad = False

        # -- inner update -- #
        losses = []
        for j in range(grad_steps):

            out, _, other = step(model,
                                 variables,
                                 loss_fn=loss_fn,
                                 transform_fn=transform_fn,
                                 optimize=True)

            step_iter += 1

            # Compute loss after inverting
            t_params = torch.stack(variables.t.data)
            inv_out = transform_fn(out, t_params, invert=True)

            loss = loss_fn(inv_out, target.repeat(inv_out.size(0), 1, 1, 1),
                           mask.repeat(inv_out.size(0), 1, 1, 1))

            losses.append(to_numpy(loss))
            outs.append(to_image(make_grid(out), cv2_format=False))

            if pbar is not None:
                pbar.progress(step_iter / total_steps)
            else:
                if (step_iter + 1) % 50 == 0:
                    progress_print('transform', step_iter + 1, total_steps,
                                   'c')

        # -- update CMA -- #
        if using_nevergrad:
        for cand, l in zip(_t, np.min(losses, 0)):
            t_cma_opt.tell(cand, l)
        else:
            t_cma_opt.tell(t, np.min(losses, 0))

        # -- log for visualization -- #
        if log and 'target' in other.keys():
            t_out = binarize(other['weight'], 0.3) * other['target']
            t_outs.append(to_image(make_grid(t_out), cv2_format=False))

    if using_nevergrad:
        t_mu = np.array(t_cma_opt.provide_recommendation().value)
    else:
        t_mu = t_cma_opt.mean()
    return torch.Tensor([t_mu]), t_outs, outs
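
The ask/tell structure above is the core of the BasinCMA-style search. Below
is a minimal, self-contained sketch of the same pattern on a toy 2-D problem,
using Nevergrad's CMA (the fallback optimizer in the code above). The
quadratic objective, the inner step size, and the analytic gradient are
illustrative assumptions, not part of pix2latent:

import numpy as np
import nevergrad as ng

def loss(t, z):
    # Toy stand-in for the reconstruction loss: it depends on both the
    # outer transform parameter t and the inner latent variable z.
    return float(np.sum((z - t) ** 2) + 0.1 * np.sum((t - 1.0) ** 2))

# Outer CMA optimizer over the transform parameter t (2-D here).
param = ng.p.Array(init=np.zeros(2)).set_mutation(sigma=0.5)
opt = ng.optimizers.registry['CMA'](parametrization=param, budget=30)

for _ in range(30):                        # outer loop: CMA over t
    cand = opt.ask()
    t = np.asarray(cand.value)
    z = np.random.randn(2)                 # resample the inner variable
    best = np.inf
    for _ in range(30):                    # inner loop: gradient descent on z
        z -= 0.1 * 2.0 * (z - t)           # analytic gradient of the toy loss
        best = min(best, loss(t, z))
    opt.tell(cand, best)                   # score t by its best inner loss

t_mu = opt.provide_recommendation().value  # should approach [1, 1]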
Example #4
    def auto_detect(self, im, mask_type='segmentation'):
        if not hasattr(self, 'detector'):
            self.detector = Detector()

        if not hasattr(self, 'classifier'):
            self.classifier = Classifier()

        candidates = self.detector(im, is_tensor=False)

        if candidates is None:
            print('Did not find any valid object in the image.')

        else:
            det_bboxes = candidates['boxes']
            det_labels = candidates['labels']
            det_scores = candidates['scores']
            det_masks = candidates['masks']

            coco_to_wnid = imagenet_tools.get_coco_valid_wnids()

            # Start from highest to lowest score
            for idx in np.argsort(det_scores.cpu().numpy())[::-1]:
                det_cls_noun = COCO_INSTANCE_CATEGORY_NAMES[det_labels[idx]]
                # np.int was removed in NumPy 1.24; use the builtin int
                bbox = det_bboxes[idx].cpu().numpy().astype(int)

                bbox_im = im[bbox[1]:bbox[3], bbox[0]:bbox[2], :]

                top5_cls = self.classifier(bbox_im, is_tensor=False, top5=True)
                pred_cls = top5_cls[1][0].item()

                misc = []
                for c in top5_cls[1]:
                    misc.append([c.item(), IMAGENET_LABEL_TO_NOUN[c.item()]])

                pred_wnid = IMAGENET_LABEL_TO_WNID[pred_cls]
                pred_cls_noun = IMAGENET_LABEL_TO_NOUN[pred_cls]

                valid_wnids = coco_to_wnid[det_cls_noun]

                if pred_wnid in valid_wnids:
                    print(('Found a match: classified class {} is consistent '
                           'with the detected class {}').format(pred_cls_noun,
                                                                det_cls_noun))

                    if mask_type == 'segmentation':
                        m = det_masks[idx] > 0.5

                    elif mask_type == 'bbox':
                        m = torch.zeros(1, im.shape[0], im.shape[1])
                        m[:, bbox[1]:bbox[3], bbox[0]:bbox[2]] = 1.0

                    else:
                        invalid_msg = 'Invalid mask_type {}'
                        raise ValueError(invalid_msg.format(mask_type))

                    self._cls_lbl = pred_cls
                    m = to_image(m, denormalize=False)
                    m = cv2.medianBlur(m.astype(np.uint8), 5)
                    self._mask = m.reshape(m.shape[0], m.shape[1], 1) / 255.
                    return self._mask, pred_cls_noun, det_cls_noun, misc

                print(('Classification and detection are inconsistent: '
                       'classified class {} is not an element of the '
                       'detected class {}. Trying the next candidate.').format(
                           pred_cls_noun, det_cls_noun))

        print('Auto-detection failed. All candidates are invalid.')

        cls_lbl = self.classifier(im, as_onehot=False, is_tensor=False)
        self._cls_lbl = cls_lbl
        self._mask = None  # no valid mask found; classify the whole image
        print('Mask is set to None and the predicted class is: {} ({})'.format(
            cls_lbl, IMAGENET_LABEL_TO_NOUN[cls_lbl]))
        return
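
The loop above is a cross-check: a detection is accepted only when the
ImageNet classification of its crop maps into the detected COCO category.
A toy sketch of that check with a hand-made, hypothetical COCO-to-WNID
mapping (the real mapping comes from imagenet_tools.get_coco_valid_wnids()):

# Hypothetical mapping: each COCO noun -> set of valid ImageNet WNIDs.
coco_to_wnid = {'dog': {'n02099712', 'n02106662'},
                'cat': {'n02123045'}}

def is_consistent(det_noun, pred_wnid):
    # Accept only if the classifier's WNID is in the WNID set
    # associated with the detected COCO category.
    return pred_wnid in coco_to_wnid.get(det_noun, set())

print(is_consistent('dog', 'n02106662'))  # True
print(is_consistent('cat', 'n02106662'))  # False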
Example #5
    else:
        if not hasattr(solver, '_cls_lbl'):
            solver.auto_detect(im, mask_type='bbox')
        cls_lbl = solver._cls_lbl

    variables, outs, losses, transform_fn = \
        solver(im, mask=mask, cls_lbl=cls_lbl,
               encoder_init=args.encoder_init,
               num_seeds=args.num_seeds,
               max_batch_size=args.max_batch_size,
               log=False)

    idx = np.argmin(losses).squeeze()
    z = torch.stack(variables.z.data)
    cv = torch.stack(variables.cv.data)

    with torch.no_grad():
        out = solver.model(z=z[idx:idx + 1], c=cv[idx:idx + 1])
        out_im = to_image(out)[0]

    t = torch.stack(variables.t.data)[idx:idx + 1]

    inv_im = to_image(transform_fn(out.cpu(), t.cpu(), invert=True))[0]
    blended = poisson_blend(im[:, :, [2, 1, 0]], mask, inv_im)

    os.makedirs('./results', exist_ok=True)

    jpg_quality = [int(cv2.IMWRITE_JPEG_QUALITY), 100]
    cv2.imwrite('./results/projected-{}.jpg'.format(fn), inv_im, jpg_quality)
    cv2.imwrite('./results/blended-{}.jpg'.format(fn), blended, jpg_quality)


def main():
    st.sidebar.markdown('<h1> pix2latent </h1>',
                        unsafe_allow_html=True)
    st.sidebar.markdown('<p> Interactive demo for: <br><b>Transforming and ' +
                        'Projecting Images to Class-conditional Generative' +
                        ' Networks</b></p>',
                        unsafe_allow_html=True)

    load_option = st.sidebar.selectbox('How do you want to load your image?',
                                       ('upload', 'search'))

    if load_option == 'upload':
        im = upload_file()
    else:
        im = file_selector('./examples')

    if im is None:
        return

    if im.shape != (256, 256, 3):
        st.sidebar.text(
            '(Warning): Image has incorrect dimensions; ' +
            'automatically cropping and resizing it.')
        im = center_crop(im)
        im = smart_resize(im)

    # Run detection and let the user choose how the object is masked
    st.sidebar.markdown('<center><h3>Selected image</h3></center>',
                        unsafe_allow_html=True)
    im_view = st.sidebar.radio(
        'masking method', ('bbox', 'segmentation'))

    mask_overlay_im, mask, p_noun, d_noun, misc = detect(im, im_view)

    st.sidebar.image(mask_overlay_im)
    st.sidebar.markdown('<b>Detected class</b>: {}'.format(d_noun),
                        unsafe_allow_html=True)
    st.sidebar.markdown('<b>Predicted class</b>: {}'.format(p_noun),
                        unsafe_allow_html=True)

    misc_nouns = np.array([x[1] for x in misc])
    selected = st.sidebar.selectbox('Change class', misc_nouns)
    selected_cls = misc[np.argwhere(misc_nouns == selected).squeeze()][0]

    # Optimization config
    st.sidebar.markdown('<center> <h3> Optimization config </h3> </center>',
                        unsafe_allow_html=True)

    num_seeds = st.sidebar.slider(
        'Number of seeds',
        min_value=1,
        max_value=18,
        value=18,
        step=1,
    )

    if num_seeds != 18:
        st.sidebar.text('(Warning): PyCMA requires num_seeds = 18. ' +
                        'Falling back to the Nevergrad implementation, ' +
                        'which may not work as well.')

    max_batch_size = st.sidebar.slider(
        'Max batch size',
        min_value=1,
        max_value=18,
        value=9,
        step=1,
    )

    cma_steps = st.sidebar.slider(
        'CMA update',
        min_value=1,
        max_value=50,
        value=30,
        step=5,
    )

    adam_steps = st.sidebar.slider(
        'ADAM update',
        min_value=1,
        max_value=50,
        value=30,
        step=5,
    )

    ft_adam_steps = st.sidebar.slider(
        'Final ADAM update',
        min_value=1,
        max_value=1000,
        value=300,
        step=50,
    )

    transform_cma_steps = st.sidebar.slider(
        'Transform CMA update',
        min_value=1,
        max_value=50,
        value=30,
        step=5,
    )

    transform_adam_steps = st.sidebar.slider(
        'Transform ADAM update',
        min_value=1,
        max_value=50,
        value=30,
        step=5,
    )

    encoder_init = st.sidebar.checkbox(
        'Encoder init',
        value=True,
    )

    start_optimization = st.sidebar.button('Optimize')

    if not start_optimization:
        return

    variables, outs, losses, transform_fn = \
        run_optimization(im, selected_cls, num_seeds, max_batch_size,
                         cma_steps, adam_steps, ft_adam_steps,
                         transform_cma_steps, transform_adam_steps,
                         encoder_init)

    # Collage
    collage_results = to_image(make_grid(outs), cv2_format=False)

    # Blended collage
    t = torch.stack(variables.t.data)[:1]
    inv_ims = []
    for out in outs:
        inv_ims.append(
            transform_fn(out.unsqueeze(0).cpu(), t.cpu(), invert=True))

    inv_collage_results = to_image(make_grid(torch.cat(inv_ims)),
                                   cv2_format=False)

    # Show Results
    blended = []

    for x in inv_ims:
        inv_im = to_image(x, cv2_format=False)[0]
        b = poisson_blend(im, mask, inv_im)
        blended.append(to_tensor(b))

    blended = torch.cat(blended)
    blended_collage_results = to_image(make_grid(blended), cv2_format=False)

    st.markdown('<h3> Projection </h3>', unsafe_allow_html=True)
    st.image(collage_results, use_column_width=True)
    st.markdown('<h3> Inverted </h3>', unsafe_allow_html=True)
    st.image(inv_collage_results, use_column_width=True)
    st.markdown('<h3> Poisson blended </h3>', unsafe_allow_html=True)
    st.image(blended_collage_results, use_column_width=True)
    return
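
poisson_blend above is pix2latent's helper for compositing the inverted
output back into the original photo. A minimal sketch of the same idea using
OpenCV's built-in seamlessClone; the file names are placeholders, and
pix2latent's own helper may differ in its masking details:

import cv2

# Placeholder inputs: a source patch, a destination photo, and a binary mask.
src = cv2.imread('projected.jpg')          # generator output (BGR, uint8)
dst = cv2.imread('original.jpg')           # original photo, same size as src
mask = cv2.imread('mask.png', cv2.IMREAD_GRAYSCALE)

# seamlessClone solves the same Poisson equation: gradients are taken from
# src and boundary colors from dst, so the seam disappears.
h, w = mask.shape
center = (w // 2, h // 2)                  # where src is placed inside dst
blended = cv2.seamlessClone(src, dst, mask, center, cv2.NORMAL_CLONE)
cv2.imwrite('blended.jpg', blended)

To launch the demo itself, a Streamlit app is started with
streamlit run <script>.py (the script name is not shown in this snippet).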