Example #1
    def __init__(self, idx: Idx, prev_idx: Idx, zero: Coords, crop_scales: Tuple[int, int], extractor: Callable, prediction: torch.Tensor):
        gc_pred, nb_pred, coords_pred = extractor(prediction)

        xc_pred, yc_pred, xn_pred, yn_pred = [u.numpify(coords_pred[i][0]) for i in range(4)]

        self.prev_idx = prev_idx
        self.idx = idx

        self.confidence = u.numpify(gc_pred[0,0])
        self.nb_confidences = u.numpify(torch.sigmoid(nb_pred[0]))
        self.color = u.numpify(torch.sigmoid(gc_pred[0,1]))
        # stack the predicted (x, y) pairs into [N, 2], rescale by the crop size, and offset by `zero`
        self.neighs = np.stack([xn_pred, yn_pred]).transpose() * crop_scales + zero
        self.corners = np.stack([xc_pred, yc_pred]).transpose() * crop_scales + zero
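
Every example on this page passes tensors through a `numpify` helper (referenced as `u.numpify`, `utils.numpify`, or plain `numpify`) whose definition is not shown. A minimal sketch of such a helper, assuming it only needs to detach, move to CPU, and convert, passing non-tensor values through unchanged:

import torch


def numpify(value):
    """Convert a torch.Tensor to a NumPy array; pass anything else through.

    Detach from the autograd graph and move to CPU first, since .numpy()
    only works on CPU tensors that do not require grad.
    """
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return value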
Example #2
    def reorder_edges(self) -> 'Graph':
        """Reorder edges according to weight matrix entries
        Needed to be consistent with pyGSP `get_edge_list()` output.

        Returns:
            Graph:
        """
        # need senders sorted, then receivers sorted within each sender
        # use numpy stable sorting (mergesort for the second sort to keep receivers ordered)
        senders = numpify(self.senders)
        receivers = numpify(self.receivers)

        indices = np.argsort(receivers)
        next_indices = np.argsort(senders[indices], kind='mergesort')
        new_indices = indices[next_indices]
        new_indices = torch.tensor(new_indices, device=self.device)

        return self.update(senders=self.senders[new_indices],
                           receivers=self.receivers[new_indices],
                           edges=self.edges[new_indices])
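
The two-pass trick above works because `mergesort` is stable: the second sort orders by sender while preserving, within equal senders, the receiver order established by the first sort. A standalone illustration with made-up sender/receiver arrays:

import numpy as np

senders = np.array([2, 0, 2, 1, 0])
receivers = np.array([1, 3, 0, 2, 1])

# Sort by receivers first, then stably by senders: the result is ordered
# by sender and, within each sender, by receiver.
indices = np.argsort(receivers)
new_indices = indices[np.argsort(senders[indices], kind='mergesort')]

print(senders[new_indices])    # [0 0 1 2 2]
print(receivers[new_indices])  # [1 3 2 0 1]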
Example #3
    def pairwise_distances(self) -> torch.Tensor:
        """ Compute pairwise distance (number of edges) between all pair of nodes
        Uses NetworkX shortest path algorithm, very slow for big graph.
        Caches the results.

        Returns:
            torch.Tensor: [n_node, n_node]
        """

        if self._distances is None:
            G = nx.DiGraph()
            G.add_edges_from(
                zip(numpify(self.senders), numpify(self.receivers)))
            G.add_nodes_from(range(self.n_node))
            self._distances = torch.zeros([self.n_node, self.n_node],
                                          device=self.device) - 1
            for source, targets in nx.shortest_path_length(G):
                for target, length in targets.items():
                    self._distances[source, target] = length

        return self._distances
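
For comparison, here is the same all-pairs hop-count pattern on a toy `nx.DiGraph`, keeping the -1 sentinel for unreachable pairs (a standalone sketch, not the class above):

import networkx as nx
import numpy as np

# A directed 3-cycle plus one isolated node (3).
G = nx.DiGraph([(0, 1), (1, 2), (2, 0)])
G.add_nodes_from(range(4))

distances = np.full((4, 4), -1.0)  # -1 marks unreachable pairs
for source, targets in nx.shortest_path_length(G):
    for target, length in targets.items():
        distances[source, target] = length

print(distances)
# [[ 0.  1.  2. -1.]
#  [ 2.  0.  1. -1.]
#  [ 1.  2.  0. -1.]
#  [-1. -1. -1.  0.]]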
Example #4
    def sample_qz_uncertainty(self, lmbda, yq):
        """Image-space pixel-wise uncertainty estimation via MC sampling"""
        B, K, _ = yq.size()
        mu_z, logvar_z = lmbda.chunk(2, dim=1)
        z = dist.Normal(mu_z, to_sigma(logvar_z)).rsample([self.config.num_mc_samples])
        var_x = []
        # var_obj = []
        for s in range(self.config.num_mc_samples):
            z_yq = self.projector(torch.cat((z[s], yq.reshape(B * K, -1)), dim=-1))
            mask_logits, mu_x = self.decode(z_yq)
            var_x.append(utils.numpify(torch.sum(torch.softmax(mask_logits, dim=1) * mu_x, dim=1)))  # [B, 3, H, W]
            # var_obj.append(utils.numpify(torch.sigmoid(mask_logits) * mu_x))  # [B, K, 3, H, W]
        var_x = np.stack(var_x, axis=0).var(axis=0, ddof=1).sum(1)
        # var_obj = np.stack(var_obj, axis=0).var(axis=0, ddof=1).sum(2)
        return var_x, None  # var_obj
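
Stripped of the model-specific pieces, the estimator above is the unbiased per-pixel sample variance of decoded reconstructions across several latent draws. A self-contained sketch of that pattern with a stand-in linear decoder (all shapes and the decoder are invented for illustration):

import numpy as np
import torch
import torch.distributions as dist

S, B, D, H, W = 8, 2, 16, 4, 4                      # MC samples, batch, latent dim, image size
mu_z, sigma_z = torch.zeros(B, D), torch.ones(B, D)
decoder = torch.nn.Linear(D, 3 * H * W)             # stand-in for the real decoder

# Draw S latent samples and decode each into an image.
z = dist.Normal(mu_z, sigma_z).rsample([S])         # [S, B, D]
recons = [decoder(z[s]).reshape(B, 3, H, W).detach().numpy() for s in range(S)]

# Unbiased (ddof=1) per-pixel variance across samples, summed over channels.
var_x = np.stack(recons, axis=0).var(axis=0, ddof=1).sum(1)  # [B, H, W]
print(var_x.shape)  # (2, 4, 4)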
Example #5
    def v_travel(self, lmbda, v_pts, save_sample_to=None, save_start_id=0):
        """
        Viewpoint queired predictions along a viewpoint trajectory.

        :param z: [B*K, D]
        :param v_pts: [B, L, dview]   (a viewpoint trajectory)
        """
        save_dir = os.path.join(save_sample_to, 'v_track')
        utils.ensure_dir(save_dir)
        B, L, _ = v_pts.size()
        K = lmbda.size(0) // B

        mu_z, logvar_z = lmbda.chunk(2, dim=-1)
        z = dist.Normal(mu_z, to_sigma(logvar_z)).rsample()

        v_feat = self.view_encoder(v_pts.reshape(B * L, -1))  # output [B*L, 8]
        v_feat = v_feat.reshape(B, L, -1).unsqueeze(1).repeat(1, K, 1, 1)

        GIFs = {
            'alpha': [],  # RGB images
            'seg': [],  # Segmentation
            'uncer': []  # Uncertainty (pixel-wise, image space)
        }
        for l in range(L):
            yq = v_feat[:, :, l, :]
            z_yq = self.projector(torch.cat((z, yq.reshape(B * K, -1)),
                                            dim=-1))
            mask_logits, mu_x = self.decode(z_yq)
            masks = torch.softmax(mask_logits, dim=1)
            x_hat = torch.sum(masks * mu_x, dim=1)
            uncer, _ = self.sample_qz_uncertainty(lmbda, yq)
            GIFs['alpha'].append(x_hat)
            GIFs['seg'].append(masks.squeeze(2))
            GIFs['uncer'].append(torch.from_numpy(uncer).to(masks).float())

        for key in GIFs.keys():
            GIFs[key] = torch.stack(GIFs[key], dim=0)  # [steps, B, #, C, H, W]

        for b in range(B):
            save_batch_dir = os.path.join(save_dir, str(b + save_start_id))
            utils.ensure_dir(save_batch_dir)
            for key in GIFs.keys():
                prefix = '{}{}'.format(b + save_start_id, key)
                for iid in range(L):
                    if key == 'alpha':
                        vis.enhance_save_single_image(
                            utils.numpify(GIFs[key][iid, b, ...].cpu().permute(
                                1, 2, 0)),
                            os.path.join(save_batch_dir,
                                         '{}_{:02d}.png'.format(prefix, iid)))
                        # save_image(tensor=GIFs[key][iid, b, ...].cpu(),
                        #            filename=os.path.join(save_batch_dir, '{}_{:02d}.jpg'.format(prefix, iid)))
                    elif key == 'seg':
                        seg = np.argmax(utils.numpify(GIFs[key][iid, b,
                                                                ...].cpu()),
                                        axis=0).astype('uint8')
                        seg = vis.save_dorder_plots(seg, K_comps=K, cmap='hsv')
                        vis.save_single_image(
                            seg,
                            os.path.join(save_batch_dir,
                                         '{}_{:02d}.png'.format(prefix, iid)))
                    elif key == 'uncer':
                        vis_var = np.log10(
                            utils.numpify(GIFs[key][iid, b, ...].cpu()) + 1e-6)
                        vis_var = vis.map_val_colors(vis_var,
                                                     v_min=-6.,
                                                     v_max=-2.,
                                                     cmap='hot')
                        vis.save_single_image(
                            vis_var,
                            os.path.join(save_batch_dir,
                                         '{}_{:02d}.png'.format(prefix, iid)))
                    else:
                        raise NotImplementedError
                vis.grid2gif(
                    str(os.path.join(save_batch_dir, prefix + '*.png')),
                    str(os.path.join(save_batch_dir, '{}.gif'.format(prefix))),
                    delay=20)
Example #6
    def predict(self,
                images,
                targets,
                save_sample_to=None,
                save_start_id=0,
                vis_train=True,
                vis_uncertainty=False):
        """
        We show uncertainty
        """
        xmul = torch.stack(images, dim=0)  # [B, V, C, H, W]
        v_pts = torch.stack([tar['view_points'] for tar in targets],
                            dim=0).type(xmul.dtype)  # [B, V, 3]

        B, V, _, _, _ = xmul.size()
        K, nit_inner_loop, z_dim = self.K, self.nit_innerloop, self.z_dim

        # sample the number of observations and which views are observed / queried
        obs_view_idx, qry_view_idx = self.sample_view_config(
            V, self.num_vq_show, self.num_vq_show, allow_repeat=False)

        # Initialize parameters for latents' distribution
        assert not torch.isnan(self.lmbda0).any().item(), 'lmbda0 has nan'
        lmbda = self.lmbda0.expand((B * K, ) + self.lmbda0.shape[1:])

        # --- get view codes ---
        v_feat = self.view_encoder(v_pts.reshape(B * V, -1))  # output [B*V, 8]
        v_feat = v_feat.reshape(B, V, -1).unsqueeze(1).repeat(1, K, 1, 1)

        # --- record for visualisation --- #
        vis_images = []
        vis_recons = []
        vis_comps = []
        vis_hiers = []
        vis_2d_latents = []
        vis_3d_latents = []

        # --- scene learning phase --- #
        for venum, v in enumerate(obs_view_idx):
            x = xmul[:, v, ...]
            y = v_feat[:, :, v, :]

            # Summarize knowledge in an iterative fashion (it does not have to be iterative, though: set T=1)
            nelbo_v, lmbda, m_logits, mu_x = self._iterative_inference(
                x, y, lmbda, nit_inner_loop)
            masks = torch.softmax(m_logits, dim=1)

            # get independent object silhouette
            indi_masks = torch.sigmoid(m_logits)

            vis_images.append(utils.numpify(x))
            vis_recons.append(utils.numpify(torch.sum(masks * mu_x, dim=1)))
            vis_comps.append(utils.numpify(indi_masks * mu_x))
            vis_hiers.append(utils.numpify(masks))

            del mu_x, m_logits, masks

        # --- scene querying phase --- #
        assert len(qry_view_idx) > 0
        for vqnum, vq in enumerate(qry_view_idx):
            x = xmul[:, vq, ...]
            yq = v_feat[:, :, vq, :]

            # making view-dependent generation
            mu_z, logvar_z = lmbda.chunk(2, dim=1)
            z = dist.Normal(mu_z, to_sigma(logvar_z)).rsample()

            z_yq = self.projector(torch.cat((z, yq.reshape(B * K, -1)),
                                            dim=-1))

            mask_logits, mu_x = self.decode(z_yq)
            # get masks
            masks = torch.softmax(mask_logits, dim=1)

            # get independent object silhouette
            indi_masks = torch.sigmoid(mask_logits)
            # uncomment the below to binarize the silhouette with a tunable threshold (default: 0.5).
            # indi_masks = (indi_masks > 0.5).type(mu_x.dtype)

            vis_images.append(utils.numpify(x))
            vis_recons.append(utils.numpify(torch.sum(masks * mu_x, dim=1)))
            vis_comps.append(utils.numpify(indi_masks * mu_x))
            vis_hiers.append(utils.numpify(masks))
            vis_2d_latents.append(utils.numpify(z_yq.reshape(B, K, -1)))
            vis_3d_latents.append(utils.numpify(z.reshape(B, K, -1)))

            del mu_x, mask_logits, masks

        vis_images = np.stack(vis_images, axis=1)  # [B, V, 3, H, W]
        vis_recons = np.stack(vis_recons, axis=1)  # [B, V, 3, H, W]
        vis_comps = np.stack(vis_comps, axis=1)  # [B, V, K, 3, H, W]
        vis_hiers = np.squeeze(np.stack(vis_hiers, axis=1),
                               axis=3)  # [B, V, K, H, W]
        vis_2d_latents = np.stack(vis_2d_latents, axis=1)  # [B, qV, K, D]
        vis_3d_latents = np.stack(vis_3d_latents, axis=1)  # [B, qV, K, D]

        if save_sample_to is not None:
            if vis_train:
                self.save_visuals(vis_images,
                                  vis_recons,
                                  vis_comps,
                                  vis_hiers,
                                  save_dir=save_sample_to,
                                  start_id=save_start_id)
            else:
                self.save_visuals_eval(len(obs_view_idx),
                                       vis_images,
                                       vis_recons,
                                       vis_comps,
                                       vis_hiers,
                                       save_dir=save_sample_to,
                                       start_id=save_start_id)
        preds = {
            'x_recon': vis_recons,
            'x_comps': vis_comps,
            'hiers': vis_hiers,
            'lmbda': lmbda,
            '2d_latents': vis_2d_latents,
            '3d_latents': vis_3d_latents,
            'scene_indices': list(tar['scn_id'][0].item() for tar in targets),
            'obs_views': obs_view_idx,
            'query_views': qry_view_idx
        }
        return preds
Example #7
    def plot_trajectory(self,
                        distributions: torch.Tensor,
                        colors: list,
                        with_edge_arrows: bool = False,
                        highlight: Union[int, List[int]] = None,
                        zoomed: bool = False,
                        ax=None,
                        normalize_intercept: bool = False,
                        edge_width: float = .1):
        """Plot a trajectory on this graph

        Args:
            distributions (torch.Tensor): [n_observations, n_node] sequence of probability distribution to plot
            colors (list): [n_observations, ] color per observation
            with_edge_arrows (bool, optional): Defaults to False. Show strongest edge direction arrow at each node
            highlight (Union[int, List[int]], optional): Defaults to None. Some node id(s) to highlight
            zoomed (bool, optional): Defaults to False. Zoom only onto the interesting part of the distributions (not too small)
            ax (optional): Defaults to None. Matplotlib axis
            normalize_intercept (bool, optional): Defaults to False. PyGSP plotting normalize intercept for widths
            edge_width (float, optional): Defaults to .1. edge width

        Returns:
            fig, ax from matplotlib
        """

        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(111)

        if zoomed:
            display_points_mask = distributions.sum(dim=0) > 1e-4
            display_coords = self.coords[display_points_mask]
            xmin, xmax = display_coords[:, 0].min(), display_coords[:, 0].max()
            ymin, ymax = display_coords[:, 1].min(), display_coords[:, 1].max()
            xcenter, ycenter = (xmax + xmin) / 2, (ymax + ymin) / 2
            size = max(xmax - xmin, ymax - ymin)
            margin = size * 1.1 / 2
            ax.set_xlim([xcenter - margin, xcenter + margin])
            ax.set_ylim([ycenter - margin, ycenter + margin])

        # plot underlying edges
        vertex_size = 0.
        if highlight is not None:
            vertex_size = np.zeros(self.n_node)
            vertex_size[highlight] = .5

        # HACK in pygsp.plotting, remove alpha at lines 533 and 541
        self.pygsp.plotting['highlight_color'] = gnx_plot.green
        self.pygsp.plotting['normalize_intercept'] = 0.
        self.plot(
            edge_width=edge_width,  # highlight=highlights,
            edges=True,
            vertex_size=vertex_size,  # transparent nodes
            vertex_color=[(0., 0., 0., 0.)] * self.n_node,
            highlight=highlight,
            ax=ax)

        # plot distributions
        transparent_colors = [
            mpl.colors.to_hex(mpl.colors.to_rgba(c, alpha=.5), keep_alpha=True)
            for c in colors
        ]

        self.pygsp.plotting['normalize_intercept'] = 0.
        for distribution, color in zip(distributions, transparent_colors):
            self.plot(vertex_size=distribution,
                      vertex_color=color,
                      edge_width=0,
                      ax=ax)

        if with_edge_arrows:
            coords = self.coords
            arrows = self.max_edge_vector_per_node()
            coords = numpify(coords)
            arrows = numpify(arrows)
            ax.quiver(coords[:, 0],
                      coords[:, 1],
                      arrows[:, 0],
                      arrows[:, 1],
                      pivot='tail')

        ax.set_aspect('equal')
        return ax
Example #8
    def plot(self, *args, **kwargs):
        """Calls pygsp.plot with numpified torch.Tensor arguments"""
        return self.pygsp.plot(*[numpify(a) for a in args],
                               **{k: numpify(v) for k, v in kwargs.items()})
Example #9
    def __init__(self, idx: Idx, prev_idx: Idx, coords: Coords, cell_size: float, priority: float):
        self.idx: Idx = idx
        self.prev_idx: Idx = prev_idx
        self.coords = u.numpify(coords)
        self.cell_size = cell_size
        self.priority = priority