def transform_source_batch(self, source, source_name):
    """Crop (and optionally flip / rescale) a batch of images.

    Two batch layouts are accepted:

    * a list, or a 1-D ndarray, whose elements are all 3-D ndarrays --
      each element is passed to ``self.transform_source_example``; the
      results are returned as a list for list input, or wrapped with
      ``numpy.array`` for ndarray input;
    * a 4-D ndarray laid out as ``(batch, channel, height, width)`` --
      every image is cropped to ``self.window_shape``: from the centre
      when ``self.center_crop`` is true, otherwise at random offsets drawn
      from ``self.rng``.

    For 4-D input the crop is returned as a new ``float32`` array (the
    caller's ``source`` is never modified). When ``self.random_lr_flip``
    is set, each image is mirrored along the width axis with probability
    1/2, using ``self.rng`` so results are reproducible with a seeded rng.
    When ``self.devide_by_255`` is set, values are divided by 255.

    Parameters
    ----------
    source : list or numpy.ndarray
        The batch to transform.
    source_name : str
        Name of the source; used to look up and verify its axis labels.

    Returns
    -------
    list or numpy.ndarray
        The transformed batch.

    Raises
    ------
    ValueError
        If a 4-D batch has images smaller than the requested window (in
        either crop mode), or if ``source`` matches neither layout.
    """
    self.verify_axis_labels(('batch', 'channel', 'height', 'width'),
                            self.data_stream.axis_labels[source_name],
                            source_name)
    windowed_height, windowed_width = self.window_shape
    is_ndarray = isinstance(source, numpy.ndarray)
    if ((isinstance(source, list) or (is_ndarray and source.ndim == 1))
            and all(isinstance(b, numpy.ndarray) and b.ndim == 3
                    for b in source)):
        examples = [self.transform_source_example(im, source_name)
                    for im in source]
        if isinstance(source, list):
            return examples
        return numpy.array(examples)
    elif is_ndarray and source.ndim == 4:
        # Hardcoded assumption of (batch, channels, height, width).
        # This is what the fast Cython code supports.
        batch_size = source.shape[0]
        image_height, image_width = source.shape[2:]
        max_h_off = image_height - windowed_height
        max_w_off = image_width - windowed_width
        # Validate for both crop modes: with a negative margin the centre
        # crop slices would silently return empty or wrong data.
        if max_h_off < 0 or max_w_off < 0:
            raise ValueError(
                "Got ndarray batch with image dimensions {} but "
                "requested window shape of {}".format(
                    source.shape[2:], self.window_shape))
        if self.center_crop:
            # deterministic center crop
            offset_y = max_h_off // 2
            offset_x = max_w_off // 2
            # Slice with an explicit end index: ``offset:-offset`` is empty
            # when offset == 0 and one pixel too wide for an odd margin.
            out = source[:, :, offset_y:offset_y + windowed_height,
                         offset_x:offset_x + windowed_width]
        else:
            # random crop
            out = numpy.empty(source.shape[:2] + self.window_shape,
                              dtype=source.dtype)
            offsets_w = self.rng.random_integers(0, max_w_off,
                                                 size=batch_size)
            offsets_h = self.rng.random_integers(0, max_h_off,
                                                 size=batch_size)
            window_batch_bchw(source, offsets_h, offsets_w, out)
        # astype() returns a copy, so the in-place flip below cannot write
        # through the centre-crop view into the caller's ``source``.
        out = out.astype(numpy.float32)
        if self.random_lr_flip:
            # Boolean mask (not 0/1 ints, which would act as indices).
            flips = self.rng.binomial(1, 0.5, batch_size).astype(bool)
            out[flips] = out[flips, :, :, ::-1]
        if self.devide_by_255:
            out /= 255.0
        return out
    else:
        raise ValueError("uninterpretable batch format; expected a list "
                         "of arrays with ndim = 3, or an array with "
                         "ndim = 4")
def transform_source_batch(self, source, source_name):
    """Randomly crop (with optional padding and flipping) a batch of images.

    Two batch layouts are accepted:

    * a 4-D ndarray laid out as ``(batch, channel, height, width)``.
      If ``self.pad`` is not None, the height and width axes are first
      zero-padded by ``self.pad`` on each side. Each image is then cropped
      to ``self.window_shape`` at offsets drawn from ``self.rng``. If
      ``self.x_flip`` is set, each cropped image is mirrored along the
      width axis with probability 1/2.
    * an iterable whose elements are all 3-D ndarrays -- each element is
      passed to ``self.transform_source_example`` and the results are
      returned as a list.

    Parameters
    ----------
    source : numpy.ndarray or list
        The batch to transform.
    source_name : str
        Name of the source; used to look up and verify its axis labels.

    Returns
    -------
    numpy.ndarray or list
        The cropped batch, with the same dtype as ``source`` for 4-D input.

    Raises
    ------
    ValueError
        If the (padded) images are smaller than the requested window, or
        if ``source`` matches neither layout.
    """
    self.verify_axis_labels(('batch', 'channel', 'height', 'width'),
                            self.data_stream.axis_labels[source_name],
                            source_name)
    windowed_height, windowed_width = self.window_shape
    if isinstance(source, numpy.ndarray) and source.ndim == 4:
        # Hardcoded assumption of (batch, channels, height, width).
        # This is what the fast Cython code supports.
        out = numpy.empty(source.shape[:2] + self.window_shape,
                          dtype=source.dtype)
        batch_size = source.shape[0]
        # If padding is requested, pad before making random crop.
        if self.pad is not None:
            symmetric = (self.pad, self.pad)
            paddims = ((0, 0), (0, 0), symmetric, symmetric)
            source = numpy.pad(source, paddims, 'constant')
        image_height, image_width = source.shape[2:]
        max_h_off = image_height - windowed_height
        max_w_off = image_width - windowed_width
        if max_h_off < 0 or max_w_off < 0:
            raise ValueError("Got ndarray batch with image dimensions {} "
                             "but requested window shape of {}".format(
                                 source.shape[2:], self.window_shape))
        offsets_w = self.rng.random_integers(0, max_w_off, size=batch_size)
        offsets_h = self.rng.random_integers(0, max_h_off, size=batch_size)
        # Presumably fills ``out`` with each image's window at
        # (offsets_h[i], offsets_w[i]) -- Cython helper defined elsewhere.
        window_batch_bchw(source, offsets_h, offsets_w, out)
        # If flipping is requested, randomly flip images horizontally
        if self.x_flip:
            # binomial() yields 0/1 *integers*; used directly as an index
            # they would select images 0 and 1 (and raise IndexError for a
            # batch of one) instead of masking each image. Cast to bool.
            whichflip = self.rng.binomial(1, 0.5, batch_size).astype(bool)
            out[whichflip] = out[whichflip, :, :, ::-1]
        return out
    elif all(isinstance(b, numpy.ndarray) and b.ndim == 3 for b in source):
        return [self.transform_source_example(im, source_name)
                for im in source]
    else:
        raise ValueError("uninterpretable batch format; expected a list "
                         "of arrays with ndim = 3, or an array with "
                         "ndim = 4")
def transform_source_batch(self, source, source_name):
    """Randomly crop (with optional padding and flipping) a batch of images.

    Two batch layouts are accepted:

    * a 4-D ndarray laid out as ``(batch, channel, height, width)``.
      If ``self.pad`` is not None, the height and width axes are first
      zero-padded by ``self.pad`` on each side. Each image is then cropped
      to ``self.window_shape`` at offsets drawn from ``self.rng``. If
      ``self.x_flip`` is set, each cropped image is mirrored along the
      width axis with probability 1/2.
    * an iterable whose elements are all 3-D ndarrays -- each element is
      passed to ``self.transform_source_example`` and the results are
      returned as a list.

    Parameters
    ----------
    source : numpy.ndarray or list
        The batch to transform.
    source_name : str
        Name of the source; used to look up and verify its axis labels.

    Returns
    -------
    numpy.ndarray or list
        The cropped batch, with the same dtype as ``source`` for 4-D input.

    Raises
    ------
    ValueError
        If the (padded) images are smaller than the requested window, or
        if ``source`` matches neither layout.
    """
    self.verify_axis_labels(('batch', 'channel', 'height', 'width'),
                            self.data_stream.axis_labels[source_name],
                            source_name)
    windowed_height, windowed_width = self.window_shape
    if isinstance(source, numpy.ndarray) and source.ndim == 4:
        # Hardcoded assumption of (batch, channels, height, width).
        # This is what the fast Cython code supports.
        out = numpy.empty(source.shape[:2] + self.window_shape,
                          dtype=source.dtype)
        batch_size = source.shape[0]
        # If padding is requested, pad before making random crop.
        if self.pad is not None:
            symmetric = (self.pad, self.pad)
            paddims = ((0, 0), (0, 0), symmetric, symmetric)
            source = numpy.pad(source, paddims, 'constant')
        image_height, image_width = source.shape[2:]
        max_h_off = image_height - windowed_height
        max_w_off = image_width - windowed_width
        if max_h_off < 0 or max_w_off < 0:
            raise ValueError("Got ndarray batch with image dimensions {} "
                             "but requested window shape of {}".format(
                                 source.shape[2:], self.window_shape))
        offsets_w = self.rng.random_integers(0, max_w_off, size=batch_size)
        offsets_h = self.rng.random_integers(0, max_h_off, size=batch_size)
        # Presumably fills ``out`` with each image's window at
        # (offsets_h[i], offsets_w[i]) -- Cython helper defined elsewhere.
        window_batch_bchw(source, offsets_h, offsets_w, out)
        # If flipping is requested, randomly flip images horizontally
        if self.x_flip:
            # binomial() yields 0/1 *integers*; used directly as an index
            # they would select images 0 and 1 (and raise IndexError for a
            # batch of one) instead of masking each image. Cast to bool.
            whichflip = self.rng.binomial(1, 0.5, batch_size).astype(bool)
            out[whichflip] = out[whichflip, :, :, ::-1]
        return out
    elif all(isinstance(b, numpy.ndarray) and b.ndim == 3 for b in source):
        return [self.transform_source_example(im, source_name)
                for im in source]
    else:
        raise ValueError("uninterpretable batch format; expected a list "
                         "of arrays with ndim = 3, or an array with "
                         "ndim = 4")