def transform_source_batch(self, source, source_name):
    """Crop a batch of images to ``self.window_shape``.

    Two batch formats are accepted:

    * a list, or a 1-D ``numpy.ndarray``, whose elements are all 3-D
      arrays: every element is handed to ``self.transform_source_example``
      and the results are returned as a list (list input) or wrapped in
      ``numpy.array`` (array input);
    * a 4-D ``numpy.ndarray``, treated as (batch, channel, height, width):
      cropped here, either at the center (``self.center_crop``) or at
      random offsets drawn from ``self.rng``, optionally flipped
      left-right per example (``self.random_lr_flip``) and optionally
      divided by 255 (``self.devide_by_255``).

    Parameters
    ----------
    source : list or numpy.ndarray
        The batch to transform.
    source_name : hashable
        Key of this source in ``self.data_stream.axis_labels``; also
        forwarded to ``verify_axis_labels`` and
        ``transform_source_example``.

    Returns
    -------
    list or numpy.ndarray
        For 4-D input, a new ``float32`` array of shape
        ``(batch, channel, window_height, window_width)``. The input
        batch is never modified.

    Raises
    ------
    ValueError
        If a 4-D batch is smaller than the requested window, or if the
        batch format is not one of the two described above.
    """
    self.verify_axis_labels(('batch', 'channel', 'height', 'width'),
                            self.data_stream.axis_labels[source_name],
                            source_name)
    windowed_height, windowed_width = self.window_shape

    is_sequence = isinstance(source, list) or (
        isinstance(source, numpy.ndarray) and source.ndim == 1)
    if is_sequence and all(
            isinstance(b, numpy.ndarray) and b.ndim == 3 for b in source):
        examples = [
            self.transform_source_example(im, source_name) for im in source
        ]
        if isinstance(source, list):
            return examples
        return numpy.array(examples)

    if not (isinstance(source, numpy.ndarray) and source.ndim == 4):
        raise ValueError("uninterpretable batch format; expected a list "
                         "of arrays with ndim = 3, or an array with "
                         "ndim = 4")

    # Hardcoded assumption of (batch, channels, height, width).
    # This is what the fast Cython code supports.
    batch_size = source.shape[0]
    image_height, image_width = source.shape[2:]
    max_h_off = image_height - windowed_height
    max_w_off = image_width - windowed_width
    # Validate for both crop modes: a negative margin would otherwise
    # make the center crop silently return a wrongly shaped array.
    if max_h_off < 0 or max_w_off < 0:
        raise ValueError(
            "Got ndarray batch with image dimensions {} but "
            "requested window shape of {}".format(
                source.shape[2:], self.window_shape))

    if self.center_crop:  # deterministic center crop
        offset_y = max_h_off // 2
        offset_x = max_w_off // 2
        # Slice by explicit end index. ``offset:-offset`` is empty when
        # the offset is 0 and one pixel too large when the margin is odd.
        out = source[:, :,
                     offset_y:offset_y + windowed_height,
                     offset_x:offset_x + windowed_width]
    else:  # random crop
        out = numpy.empty(
            source.shape[:2] + (windowed_height, windowed_width),
            dtype=source.dtype)
        # NOTE(review): ``random_integers`` is deprecated in NumPy in
        # favour of ``randint(low, high + 1)``; migrate once the offset
        # dtype expected by ``window_batch_bchw`` is confirmed.
        offsets_w = self.rng.random_integers(0,
                                             max_w_off,
                                             size=batch_size)
        offsets_h = self.rng.random_integers(0,
                                             max_h_off,
                                             size=batch_size)
        window_batch_bchw(source, offsets_h, offsets_w, out)

    # ``astype`` returns a fresh array, so the in-place flip below can
    # never write through a center-crop view into the caller's batch.
    out = out.astype(numpy.float32)

    if self.random_lr_flip:
        # NOTE(review): flip decisions come from the module-level
        # ``random`` rather than ``self.rng`` -- confirm this is intended.
        for example in out:
            if random.randint(0, 1):
                # Copy the reversed view first: source and destination
                # of the assignment overlap in memory.
                example[:] = example[:, :, ::-1].copy()

    if self.devide_by_255:
        out /= 255.0

    return out
Exemple #2
0
 def transform_source_batch(self, source, source_name):
     """Randomly crop (and optionally pad and flip) a batch of images.

     A 4-D ``numpy.ndarray`` is treated as (batch, channel, height,
     width) and handled here: it is zero-padded by ``self.pad`` pixels on
     each spatial side when ``self.pad`` is not None, a window of
     ``self.window_shape`` is cut from every example at offsets drawn
     from ``self.rng``, and, when ``self.x_flip`` is set, each example is
     mirrored along its last axis with probability 0.5.

     Any other iterable whose elements are all 3-D arrays is processed
     element by element through ``self.transform_source_example`` and
     returned as a list.

     Raises
     ------
     ValueError
         If the (padded) images are smaller than the window, or if the
         batch format is neither of the two described above.
     """
     self.verify_axis_labels(('batch', 'channel', 'height', 'width'),
                             self.data_stream.axis_labels[source_name],
                             source_name)
     windowed_height, windowed_width = self.window_shape
     if isinstance(source, numpy.ndarray) and source.ndim == 4:
         # Hardcoded assumption of (batch, channels, height, width).
         # This is what the fast Cython code supports.
         batch_size = source.shape[0]
         # Only the first two dimensions are read here, and padding
         # below leaves those untouched.
         out = numpy.empty(
             source.shape[:2] + (windowed_height, windowed_width),
             dtype=source.dtype)
         # If padding is requested, pad before making random crop.
         if self.pad is not None:
             symmetric = (self.pad, self.pad)
             paddims = ((0, 0), (0, 0), symmetric, symmetric)
             source = numpy.pad(source, paddims, 'constant')
         image_height, image_width = source.shape[2:]
         max_h_off = image_height - windowed_height
         max_w_off = image_width - windowed_width
         if max_h_off < 0 or max_w_off < 0:
             raise ValueError("Got ndarray batch with image dimensions {} "
                              "but requested window shape of {}".format(
                                  source.shape[2:], self.window_shape))
         offsets_w = self.rng.random_integers(0, max_w_off, size=batch_size)
         offsets_h = self.rng.random_integers(0, max_h_off, size=batch_size)
         window_batch_bchw(source, offsets_h, offsets_w, out)
         # If flipping is requested, randomly flip images horizontally.
         if self.x_flip:
             # The draw must be a boolean mask. Used directly, the 0/1
             # integers act as positions, so only examples 0 and 1 would
             # ever be flipped.
             whichflip = self.rng.binomial(1, 0.5, batch_size).astype(bool)
             # Mask indexing on the right-hand side yields a copy, so the
             # assignment does not read from memory it is overwriting.
             out[whichflip] = out[whichflip, :, :, ::-1]
         return out
     elif all(isinstance(b, numpy.ndarray) and b.ndim == 3 for b in source):
         return [
             self.transform_source_example(im, source_name) for im in source
         ]
     else:
         raise ValueError("uninterpretable batch format; expected a list "
                          "of arrays with ndim = 3, or an array with "
                          "ndim = 4")
Exemple #3
0
 def transform_source_batch(self, source, source_name):
     """Pad, randomly crop and optionally flip a batch of images.

     Accepted inputs:

     * a 4-D ``numpy.ndarray``, interpreted as (batch, channel, height,
       width). When ``self.pad`` is not None the two spatial axes are
       zero-padded by ``self.pad`` on both sides; a window of
       ``self.window_shape`` is then extracted from each example at
       offsets drawn from ``self.rng``. When ``self.x_flip`` is set,
       every example is independently reversed along its last axis with
       probability 0.5. The cropped array is returned.
     * an iterable consisting solely of 3-D arrays, in which case a list
       of ``self.transform_source_example`` results is returned.

     Raises
     ------
     ValueError
         When the window does not fit into the (padded) images, or when
         the batch matches neither accepted format.
     """
     self.verify_axis_labels(('batch', 'channel', 'height', 'width'),
                             self.data_stream.axis_labels[source_name],
                             source_name)
     windowed_height, windowed_width = self.window_shape
     if isinstance(source, numpy.ndarray) and source.ndim == 4:
         # Hardcoded assumption of (batch, channels, height, width).
         # This is what the fast Cython code supports.
         batch_size = source.shape[0]
         # Batch and channel sizes are unaffected by the padding below,
         # so the output buffer can be sized before padding.
         out = numpy.empty(
             source.shape[:2] + (windowed_height, windowed_width),
             dtype=source.dtype)
         # If padding is requested, pad before making random crop.
         if self.pad is not None:
             symmetric = (self.pad, self.pad)
             paddims = ((0, 0), (0, 0), symmetric, symmetric)
             source = numpy.pad(source, paddims, 'constant')
         image_height, image_width = source.shape[2:]
         max_h_off = image_height - windowed_height
         max_w_off = image_width - windowed_width
         if max_h_off < 0 or max_w_off < 0:
             raise ValueError("Got ndarray batch with image dimensions {} "
                              "but requested window shape of {}".format(
                                  source.shape[2:], self.window_shape))
         offsets_w = self.rng.random_integers(0, max_w_off, size=batch_size)
         offsets_h = self.rng.random_integers(0, max_h_off, size=batch_size)
         window_batch_bchw(source, offsets_h, offsets_w, out)
         # If flipping is requested, randomly flip images horizontally.
         if self.x_flip:
             # Convert the 0/1 draws into a boolean mask. An integer
             # array would index examples 0 and 1 by position instead of
             # selecting the examples whose draw was 1.
             flip_mask = self.rng.binomial(1, 0.5, batch_size) != 0
             # The masked read produces a copy, so writing back into
             # ``out`` is safe.
             out[flip_mask] = out[flip_mask, :, :, ::-1]
         return out
     elif all(isinstance(b, numpy.ndarray) and b.ndim == 3 for b in source):
         return [self.transform_source_example(im, source_name)
                 for im in source]
     else:
         raise ValueError("uninterpretable batch format; expected a list "
                          "of arrays with ndim = 3, or an array with "
                          "ndim = 4")