예제 #1
0
    def test_imageToInput(self):
        """Uint8 image must become a valid network input."""
        num_layers = 7
        im_dim = Shape(3, 512, 512)
        img = 100 * np.ones(im_dim.hwc(), np.uint8)

        # Create a network (parameters do not matter).
        net = orpac_net.Orpac(self.sess, im_dim, num_layers, 10, None, False)

        # Image must be converted to float32 CHW image with leading
        # batch dimension of 1.
        img_wl = net._imageToInput(img)
        dim_xin = Shape(int(net._xin.shape[1]), *net.outputShape().hw())
        assert img_wl.dtype == np.float32
        assert img_wl.shape == (1, *dim_xin.chw())

        # All values must have been divided by 255. Since the image is
        # constant, every detail coefficient vanishes and every approximation
        # coefficient equals the normalised pixel value scaled by 2 per
        # decomposition (the 'db2' low-pass filter sums to sqrt(2) along each
        # of the two image axes).
        scale = 2 ** net._NUM_WAVELET_DECOMPOSITIONS
        assert np.allclose(img_wl.max(), scale * 100 / 255, rtol=1e-4)
        assert np.allclose(img_wl.min(), 0, atol=1e-4)
예제 #2
0
class Orpac:
    """Fully convolutional network for BBoxes, foreground flags and labels.

    The network consumes a Wavelet decomposed RGB image (see
    `_imageToInput`) and produces a CHW feature map with
    `numOutputChannels(num_classes)` channels, laid out as:

        [0:4]  BBox rectangles   (`getBBoxRects` / `setBBoxRects`)
        [4:6]  foreground flags  (`getIsFg` / `setIsFg`)
        [6:]   class labels      (`getClassLabel` / `setClassLabel`)

    All nodes live in the default TensorFlow graph and are evaluated with
    the session passed to the constructor.
    """
    # Number of times to decompose the input image with Wavelets.
    _NUM_WAVELET_DECOMPOSITIONS = 3

    def __init__(self, sess, im_dim, num_layers, num_classes, bw_init, train):
        """Build the Orpac graph.

        Inputs:
            sess: Tensorflow session used to evaluate all nodes.
            im_dim: Shape
                Dimensions of the RGB input image.
            num_layers: int
                Total number of conv layers, including the output layer.
            num_classes: int
                Number of classes in the data set.
            bw_init:
                Initial biases/weights, forwarded verbatim to
                `unpackBiasAndWeight` (presumably None selects a default
                initialisation -- confirm there).
            train: bool
                Create the cost and optimiser nodes if True.
        """
        # Decide if we want to create cost nodes or not.
        assert isinstance(train, bool)

        # Backup basic variables.
        self._trainable = train
        self.sess = sess
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.im_dim = im_dim

        # Create placeholder variable for Wavelet decomposed image.
        self._xin = self._createInputTensor(im_dim)

        # Setup the NMS nodes and Orpac network.
        self._setupNonMaxSuppression()
        with tf.variable_scope('orpac'):
            self.out = self._setupNetwork(self._xin, bw_init, np.float32)

        # Store shape of the output tensor (without the batch dimension).
        self.ft_dim = Shape(*self.out.shape.as_list()[1:])

        # Define the cost nodes and compile them into a dictionary if this
        # network is trainable, otherwise do nothing.
        if self._trainable:
            self._cost_nodes, self._optimiser = self._addOptimiser()
        else:
            self._cost_nodes, self._optimiser = {}, None

    def session(self):
        """Return Tensorflow session"""
        return self.sess

    def getBias(self, layer):
        """Return the current value of the bias tensor `orpac/b{layer}`."""
        g = tf.get_default_graph().get_tensor_by_name
        return self.sess.run(g(f'orpac/b{layer}:0'))

    def getWeight(self, layer):
        """Return the current value of the weight tensor `orpac/W{layer}`."""
        g = tf.get_default_graph().get_tensor_by_name
        return self.sess.run(g(f'orpac/W{layer}:0'))

    def numLayers(self):
        """Return the number of conv layers (including the output layer)."""
        return self.num_layers

    def numClasses(self):
        """Return the number of classes the network distinguishes."""
        return self.num_classes

    def outputShape(self):
        """Return the shape of the network output (exclusive Batch dimension).

        For example, the output may be Shape(chan=18, height=64, width=64).
        """
        # Sanity check: the number of output channels must match the value
        # returned by `numOutputChannels`.
        assert self.ft_dim.chan == self.numOutputChannels(self.numClasses())
        return self.ft_dim.copy()

    def imageShape(self):
        """Return a copy of the expected input image dimensions."""
        return self.im_dim.copy()

    def output(self):
        """Return the network output tensor."""
        return self.out

    def trainable(self):
        """Return True if the cost and optimiser nodes were created."""
        return self._trainable

    def costNodes(self):
        """Return a (shallow) copy of the cost node dictionary."""
        return dict(self._cost_nodes)

    @staticmethod
    def numOutputChannels(num_classes: int):
        """Return the number of feature channels when there are `num_classes`.

        This value specifies the number of channels that the final network
        layer will return.

        NOTE: this returns the same value as `outputShape().chan` but does not
        require an Orpac instance since it is a static method.

        Input:
            num_classes: int
                The number of output channels depends on the number of classes
                in the data set. This variable specifies that number.

        Returns:
            int: number of channels in final network output layer.
        """
        # 4 BBox channels + 2 foreground channels + one per class.
        return 4 + 2 + num_classes

    @staticmethod
    def setBBoxRects(y, val):
        """Return a copy of CHW tensor `y` with channels [0:4] set to `val`."""
        y = np.array(y)
        assert y.ndim == 3
        assert np.array(val).shape == y[:4].shape
        y[:4] = val
        return y

    @staticmethod
    def getBBoxRects(y):
        """Return channels [0:4] (BBox rectangles) of CHW tensor `y`."""
        assert y.ndim == 3
        return y[:4]

    @staticmethod
    def setIsFg(y, val):
        """Return a copy of CHW tensor `y` with channels [4:6] set to `val`."""
        y = np.array(y)
        assert y.ndim == 3
        assert np.array(val).shape == y[4:6].shape
        y[4:6] = val
        return y

    @staticmethod
    def getIsFg(y):
        """Return channels [4:6] (foreground flags) of CHW tensor `y`."""
        assert y.ndim == 3
        return y[4:6]

    @staticmethod
    def setClassLabel(y, val):
        """Return a copy of CHW tensor `y` with channels [6:] set to `val`."""
        y = np.array(y)
        assert y.ndim == 3
        assert np.array(val).shape == y[6:].shape
        y[6:] = val
        return y

    @staticmethod
    def getClassLabel(y):
        """Return channels [6:] (class labels) of CHW tensor `y`."""
        assert y.ndim == 3
        return y[6:]

    @staticmethod
    def _waveletInputDim(height, width, num_decompositions):
        """Return the NCHW shape of a Wavelet decomposed RGB image.

        Every decomposition halves the spatial dimensions and turns each
        channel into four (approximation + three details).

        Inputs:
            height, width: int
                Spatial dimensions of the original RGB image.
            num_decompositions: int
                How many times each channel is decomposed.

        Returns:
            tuple: (1, 3 * 4**N, height >> N, width >> N), N=num_decompositions
        """
        N = num_decompositions
        return (1, 3 * (4**N), height >> N, width >> N)

    def _createInputTensor(self, im_dim):
        """Return the float32 placeholder for the Wavelet decomposed image.

        The placeholder uses NCHW layout with a batch size of 1.
        """
        height, width = im_dim.hw()
        x_dim = self._waveletInputDim(
            int(height), int(width), self._NUM_WAVELET_DECOMPOSITIONS)
        return tf.placeholder(tf.float32, x_dim, name='x_in')

    def _addOptimiser(self):
        """Create the cost nodes and an Adam optimiser.

        Returns:
            (dict, op): cost tensors keyed by 'cls', 'bbox', 'isFg' and
            'total', and the optimiser's minimise operation. The learning
            rate is fed through the 'lrate' placeholder.
        """
        cost = createCostNodes(self.out)
        g = tf.get_default_graph().get_tensor_by_name
        lrate_in = tf.placeholder(tf.float32, name='lrate')
        opt = tf.train.AdamOptimizer(learning_rate=lrate_in).minimize(cost)
        nodes = {
            'cls': g('orpac-cost/cls:0'),
            'bbox': g('orpac-cost/bbox:0'),
            'isFg': g('orpac-cost/isFg:0'),
            'total': g('orpac-cost/total:0'),
        }
        return nodes, opt

    def _imageToInput(self, img):
        """Return Wavelet decomposed `img`

        The returned tensor is compatible with this class' `_xin` placeholder.

        The image dimensions must match those returned by `imageShape`, ie.
        it must be square, RGB and all its dimension must be powers of 2.

        Each colour channel will be decomposed self._NUM_WAVELET_DECOMPOSITIONS
        times.

        Inputs:
            img: UInt8 Array[height, width, 3]

        Output:
            Array[1, *imageToWaveletDim(img_shape)]
                The output dimension depends on the number of decompositions
                and the input size. For a 512x512x3 image with 3 decompositions
                the output would have Shape(chan=192, height=64, width=64).
        """
        # Sanity check.
        assert isinstance(img, np.ndarray) and img.dtype == np.uint8

        im_dim = self.imageShape()
        assert img.shape == im_dim.hwc()
        assert im_dim.isSquare() and im_dim.isPow2()

        # Normalise the image and put each colour channels as a separate image
        # into a work list.
        img = img.astype(np.float32) / 255
        src = list(img.transpose([2, 0, 1]))

        # Decompose each channel.
        for i in range(self._NUM_WAVELET_DECOMPOSITIONS):
            # Side length of the trimmed coefficients after this level.
            N = im_dim.width >> (i + 1)

            # Apply wavelet transform to every image in the worklist and place
            # the results in an output list. NOTE: `pop()` consumes the work
            # list from the end, so the channel order is reversed at every
            # level. This defines the channel layout the network is trained
            # on; do not change it without retraining.
            dst = []
            while len(src) > 0:
                cA, (cH, cV, cD) = pywt.dwt2(src.pop(),
                                             'db2',
                                             mode='symmetric')

                # All coefficients must be square and have identical dimensions.
                assert cA.shape == cH.shape == cV.shape == cD.shape
                assert cA.ndim == 2 and cA.shape[0] == cA.shape[1]

                # The wavelet decomposition reduces dimension by roughly 2.
                # However, due to transients the outputs are a bit larger than
                # that which is why we must trim them. Here we compute the
                # start/stop indices for the trimming.
                excess = cA.shape[0] - N
                assert excess >= 0
                a = excess // 2
                b = a + N
                assert b <= cA.shape[0]

                # Trim the coefficients.
                dst.append(cA[a:b, a:b])
                dst.append(cH[a:b, a:b])
                dst.append(cV[a:b, a:b])
                dst.append(cD[a:b, a:b])

            # Copy the output into the new work list and repeat the process.
            src = dst

        # Convert the Python list to Numpy and verify its shape.
        data = np.array(src, np.float32)
        assert data.shape == imageToWaveletDim(im_dim).chw()

        # Return the decomposed image with the leading batch dimension.
        return np.expand_dims(data, 0)

    def _setupNetwork(self, x_in, bw_init, dtype):
        """Create the conv layers and return the output tensor `out`.

        All but the last layer are 3x3 convolutions with `num_ft_chan`
        features and ReLU activation. The last layer is a linear 33x33
        convolution with `numOutputChannels(num_classes)` features. All
        convolutions use stride 1 and SAME padding, so the spatial dimensions
        of `x_in` are preserved.
        """
        # Convenience: arguments shared by all conv2d calls.
        opts = dict(padding='SAME', data_format='NCHW', strides=[1, 1, 1, 1])
        num_ft_chan = 64

        # Hidden conv layers.
        # Example dimensions assume a 512x512 RGB image with 3 Wavelet
        # decompositions, ie. an input of 192 channels at 64x64.
        # Input : [-1, 192, 64, 64] ---> [-1, 64, 64, 64]
        # Kernel: 3x3  Features: 64
        prev = x_in
        for i in range(self.num_layers - 1):
            prev_shape = tuple(prev.shape.as_list())
            b_dim = (num_ft_chan, 1, 1)
            W_dim = (3, 3, prev_shape[1], num_ft_chan)
            b, W = unpackBiasAndWeight(bw_init, b_dim, W_dim, i, dtype)

            prev = tf.nn.relu(tf.nn.conv2d(prev, W, **opts) + b)
            del i, b, W, b_dim, W_dim

        # Conv output layer to learn the BBoxes and class labels.
        # Shape: [-1, 64, 64, 64] ---> [-1, num_out_chan, 64, 64]
        # Kernel: 33x33
        num_out_chan = self.numOutputChannels(self.num_classes)
        prev_shape = tuple(prev.shape.as_list())
        b_dim = (num_out_chan, 1, 1)
        # Use the plain int channel count, like the hidden layers do, rather
        # than a TF Dimension object.
        W_dim = (33, 33, prev_shape[1], num_out_chan)
        b, W = unpackBiasAndWeight(bw_init, b_dim, W_dim, self.num_layers - 1,
                                   dtype)
        return tf.add(tf.nn.conv2d(prev, W, **opts), b, name='out')

    def _setupNonMaxSuppression(self):
        """Create non-maximum-suppression nodes.

        These are irrelevant for training but useful in the predictor to cull
        the flood of possible bounding boxes. At most 30 boxes survive, and
        the IoU threshold is 0.2.
        """
        with tf.variable_scope('non-max-suppression'):
            r_in = tf.placeholder(tf.float32, [None, 4], name='bb_rects')
            s_in = tf.placeholder(tf.float32, [None], name='scores')
            tf.image.non_max_suppression(r_in, s_in, 30, 0.2, name='op')

    def nonMaxSuppression(self, bb_rects, scores):
        """ Wrapper around Tensorflow's non-max-suppression function.

        The nodes are created in `_setupNonMaxSuppression`, which caps the
        result at 30 boxes with an IoU threshold of 0.2.

        Input:
            bb_rects: Array[N, 4]
                BBox rectangles, one per row.
            scores: Array[N]
                One scalar score for each BBox.

        Returns:
            idx: Array
                List of BBox indices that survived the operation.
        """
        g = tf.get_default_graph().get_tensor_by_name
        fd = {
            g('non-max-suppression/scores:0'): scores,
            g('non-max-suppression/bb_rects:0'): bb_rects,
        }
        return self.sess.run(g('non-max-suppression/op:0'), feed_dict=fd)

    def train(self, img, y, lrate, mask_cls, mask_bbox, mask_isFg):
        """Run one optimisation step on `img` and return the cost values.

        Inputs:
            img: UInt8 Array[height, width, 3]
                Raw input image (see `_imageToInput`).
            y: Array with shape `outputShape().chw()`
                Ground truth output tensor (without batch dimension).
            lrate: float
                Learning rate; must be positive.
            mask_cls, mask_bbox, mask_isFg: Array[height, width]
                Masks fed to the cost nodes; must match `y.shape[1:]`.

        Returns:
            dict: the evaluated cost nodes (see `costNodes`).
        """
        assert self._trainable

        # Sanity checks
        assert lrate > 0
        assert mask_cls.shape == mask_bbox.shape == mask_isFg.shape
        assert y.shape == self.ft_dim.chw()
        assert y.shape[1:] == mask_cls.shape

        # Feed dictionary.
        g = tf.get_default_graph().get_tensor_by_name
        fd = {
            self._xin: self._imageToInput(img),
            g('lrate:0'): lrate,
            g('orpac-cost/y_true:0'): np.expand_dims(y, 0),
            g('orpac-cost/mask_cls:0'): mask_cls,
            g('orpac-cost/mask_bbox:0'): mask_bbox,
            g('orpac-cost/mask_isFg:0'): mask_isFg,
        }

        # Run one optimisation step and return the costs.
        nodes = [self._cost_nodes, self._optimiser]
        costs, _ = self.sess.run(nodes, feed_dict=fd)
        return costs

    def predict(self, img):
        """Return the network output for `img` without the batch dimension."""
        # Run predictor network.
        g = tf.get_default_graph().get_tensor_by_name
        out = self.sess.run(g('orpac/out:0'),
                            feed_dict={self._xin: self._imageToInput(img)})
        assert out.ndim == 4 and out.shape[0] == 1
        return out[0]

    def serialise(self):
        """Return all biases, weights and the layer count as a dictionary."""
        out = {'weight': {}, 'bias': {}, 'num-layers': self.numLayers()}
        for i in range(self.num_layers):
            out['bias'][i] = self.getBias(i)
            out['weight'][i] = self.getWeight(i)
        return out