예제 #1
0
    def _tile_figs(self, space_width):
        """Tile a stack of figures into a roughly square grid image.

        ``self.figs`` is indexed as ``(num_figs, height, width, ...)``; any
        trailing axes are carried through unchanged. Figures are laid out
        row-major in an ``n x n`` grid with ``n = ceil(sqrt(num_figs))``;
        unused grid cells are filled with the padding value.

        Args:
            space_width: Number of padding pixels inserted between adjacent
                figures, both vertically and horizontally. ``0`` tiles the
                figures edge to edge.

        Returns:
            Array of shape ``(n * (height + space_width) - space_width,
            n * (width + space_width) - space_width, ...)``.
        """
        num_figs = self.figs.shape[0]
        n = int(np.ceil(np.sqrt(num_figs)))
        # Pad the figure axis up to n**2 cells and append `space_width`
        # pixels below and to the right of every figure (the gap between
        # tiles); trailing axes (if any) are not padded.
        padding = (((0, n**2 - num_figs), (0, space_width), (0, space_width))
                   + ((0, 0), ) * (self.figs.ndim - 3))
        # Pad with white: 255 when is_intarr(self.figs), otherwise 1.0.
        constant_values = 255 if is_intarr(self.figs) else 1.

        tiled_figs = np.pad(self.figs,
                            padding,
                            mode='constant',
                            constant_values=constant_values)

        # (n*n, H', W', ...) -> (n, n, H', W', ...) -> (n, H', n, W', ...),
        # so each grid row's figures sit side by side before flattening.
        tiled_figs = tiled_figs.reshape((n, n) + tiled_figs.shape[1:])
        tiled_figs = tiled_figs.transpose(
            (0, 2, 1, 3) + tuple(range(4, tiled_figs.ndim)))
        tiled_figs = tiled_figs.reshape(
            (n * tiled_figs.shape[1], n * tiled_figs.shape[3])
            + tiled_figs.shape[4:])

        # Delete the padding border after the last row/column of tiles.
        # An explicit stop is used because `[:-space_width]` would yield an
        # empty array when space_width == 0.
        height = tiled_figs.shape[0] - space_width
        width = tiled_figs.shape[1] - space_width
        return tiled_figs[:height, :width, ...]
예제 #2
0
    def _tile_figs(self, space_width):
        """Stack figures vertically into a single column image.

        ``self.figs`` is indexed as ``(num_figs, height, ...)``; all axes after
        the height axis are carried through unchanged.

        Args:
            space_width: Number of padding rows inserted between consecutive
                figures. ``0`` stacks the figures edge to edge.

        Returns:
            Array of shape ``(num_figs * (height + space_width) - space_width,
            ...)``.
        """
        # Append `space_width` rows below every figure; other axes untouched.
        padding = (((0, 0), (0, space_width))
                   + ((0, 0), ) * (self.figs.ndim - 2))
        # Pad with white: 255 when is_intarr(self.figs), otherwise 1.0.
        constant_values = 255 if is_intarr(self.figs) else 1.

        tiled_figs = np.pad(self.figs,
                            padding,
                            mode='constant',
                            constant_values=constant_values)

        # (N, H', ...) -> (N * H', ...): concatenates the figures top to bottom.
        tiled_figs = tiled_figs.reshape(
            (tiled_figs.shape[0] * tiled_figs.shape[1], ) + tiled_figs.shape[2:])

        # Drop the gap after the last figure. An explicit stop is used because
        # `[:-space_width]` would yield an empty array when space_width == 0.
        return tiled_figs[:tiled_figs.shape[0] - space_width, ...]
예제 #3
0
    def _tile_figs(self, space_width):
        """Place figures side by side into a single row image.

        ``self.figs`` is indexed as ``(num_figs, height, width, ...)``; all axes
        after the width axis are carried through unchanged.

        Args:
            space_width: Number of padding columns inserted between adjacent
                figures. ``0`` places the figures edge to edge.

        Returns:
            Array of shape ``(height, num_figs * (width + space_width)
            - space_width, ...)``.
        """
        # Append `space_width` columns to the right of every figure.
        padding = (((0, 0), (0, 0), (0, space_width))
                   + ((0, 0), ) * (self.figs.ndim - 3))
        # Pad with white: 255 when is_intarr(self.figs), otherwise 1.0.
        constant_values = 255 if is_intarr(self.figs) else 1.

        tiled_figs = np.pad(self.figs,
                            padding,
                            mode='constant',
                            constant_values=constant_values)

        # (N, H, W', ...) -> (H, N, W', ...) -> (H, N * W', ...): concatenates
        # the figures left to right.
        tiled_figs = tiled_figs.transpose(
            (1, 0, 2) + tuple(range(3, tiled_figs.ndim)))
        tiled_figs = tiled_figs.reshape(
            (tiled_figs.shape[0], tiled_figs.shape[1] * tiled_figs.shape[2])
            + tiled_figs.shape[3:])

        # Drop the gap after the last figure. An explicit stop is used because
        # `[:, :-space_width]` would yield an empty array when space_width == 0.
        return tiled_figs[:, :tiled_figs.shape[1] - space_width, ...]
예제 #4
0
    def _tile_figs(self, space_width):
        """Tile a stack of figures into a ``row_num x col_num`` grid image.

        ``self.figs`` is indexed as ``(num_figs, height, width, ...)``; any
        trailing axes are carried through unchanged. At least one of
        ``self.row_num`` / ``self.col_num`` must be set. A missing one is
        derived so the grid can hold every figure, and is written back to
        ``self``. Figures are laid out row-major; unused grid cells are
        filled with the padding value.

        Args:
            space_width: Number of padding pixels inserted between adjacent
                figures, both vertically and horizontally. ``0`` tiles the
                figures edge to edge.

        Returns:
            Array of shape ``(row_num * (height + space_width) - space_width,
            col_num * (width + space_width) - space_width, ...)``.

        Raises:
            ValueError: If both ``row_num`` and ``col_num`` are None, or if
                ``row_num * col_num`` is smaller than the number of figures.
        """
        if self.row_num is None and self.col_num is None:
            raise ValueError(
                f"row_num={self.row_num} or col_num={self.col_num} should be int, not None"
            )

        figs_num = self.figs.shape[0]
        # NOTE: the derived dimension is stored on `self`, so later calls
        # reuse it as-is even if the number of figures changes.
        if self.row_num is None:
            self.row_num = int(np.ceil(figs_num / self.col_num))
        elif self.col_num is None:
            self.col_num = int(np.ceil(figs_num / self.row_num))

        n = self.row_num * self.col_num

        if n < figs_num:
            raise ValueError(f"row_num * col_num = {n} < self.figs.shape[0]")

        # Pad the figure axis up to n cells and append `space_width` pixels
        # below and to the right of every figure (the gap between tiles);
        # trailing axes (if any) are not padded.
        padding = (((0, n - figs_num), (0, space_width), (0, space_width))
                   + ((0, 0), ) * (self.figs.ndim - 3))
        # Pad with white: 255 when is_intarr(self.figs), otherwise 1.0.
        constant_values = 255 if is_intarr(self.figs) else 1.

        tiled_figs = np.pad(self.figs,
                            padding,
                            mode='constant',
                            constant_values=constant_values)

        # (rows*cols, H', W', ...) -> (rows, cols, H', W', ...)
        # -> (rows, H', cols, W', ...), then flatten into one image.
        tiled_figs = tiled_figs.reshape(
            (self.row_num, self.col_num) + tiled_figs.shape[1:])
        tiled_figs = tiled_figs.transpose(
            (0, 2, 1, 3) + tuple(range(4, tiled_figs.ndim)))
        tiled_figs = tiled_figs.reshape(
            (self.row_num * tiled_figs.shape[1],
             self.col_num * tiled_figs.shape[3]) + tiled_figs.shape[4:])

        # Delete the padding border after the last row/column of tiles.
        # An explicit stop is used because `[:-space_width]` would yield an
        # empty array when space_width == 0.
        height = tiled_figs.shape[0] - space_width
        width = tiled_figs.shape[1] - space_width
        return tiled_figs[:height, :width, ...]
def resolve_loc_cue_conflict_by_priority(loc_cue: np.ndarray, priority: List[np.ndarray]) -> np.ndarray:
    """Resolve conflicts in a candidate seed according to priority between channels.

    A pixel may be claimed by several channels in ``loc_cue``. For each
    channel, every pixel that is also claimed by any channel listed as having
    priority over it is zeroed in that channel.

    Args:
        loc_cue: (C, H, W) integer numpy array. ``loc_cue[c, h, w] > 0`` means
            channel ``c`` is one of the candidate channels of pixel ``(h, w)``.
        priority: Length-C sequence. ``priority[c]`` indexes (along axis 0 of
            ``loc_cue``) the channels that take priority over channel ``c``;
            it may be empty.

    Returns:
        (C, H, W) numpy array: a copy of ``loc_cue`` in which channel ``c`` is
        zeroed wherever one of its higher-priority channels is active.

    Raises:
        AssertionError: If ``loc_cue`` is not a 3-D integer array, if
            ``len(priority) != C``, or if the resolved seed still has a
            per-pixel channel sum outside [0, 1] (e.g. two overlapping
            channels that ``priority`` does not order).
    """
    assert loc_cue.ndim == 3
    assert is_intarr(loc_cue)

    assert len(priority) == loc_cue.shape[0]

    one_hot_seed = loc_cue.copy()
    for ch_idx, prior_channels in enumerate(priority):
        # No higher-priority channels: this channel keeps all its pixels.
        # Skipping explicitly also avoids an IndexError for an empty float
        # array such as np.array([]), which NumPy rejects as an index.
        if np.size(prior_channels) == 0:
            continue
        # Pixels claimed (> 0) by at least one higher-priority channel. The
        # original loc_cue is read, not the partially resolved copy.
        claimed = np.any(loc_cue[prior_channels, ...] > 0, axis=0)
        one_hot_seed[ch_idx][claimed] = 0

    # Each pixel must end up with a channel sum in [0, 1]. Both comparisons
    # need parentheses: `&` binds tighter than `>=`, and without them the
    # check reduces to `((s <= 1) & s) >= 0`, which never fails.
    seed_sum = np.sum(one_hot_seed, axis=0)
    assert np.all((seed_sum <= 1) & (seed_sum >= 0))

    return one_hot_seed