def _restart(self): # make a deep copy data = [d.copy() for d in self.data_copy] label = [l.copy() for l in self.label_copy] label_mask = [l.copy() for l in self.label_mask_copy] # duplicate if necessary to fill batch num_samples = len(data) if num_samples < self.batch_size: idx = np.concatenate((np.tile(np.arange(num_samples), (self.batch_size // num_samples, )), np.random.permutation(num_samples)[:(self.batch_size % num_samples)]), axis=0) data, label_mask, label = [data[i] for i in idx], [label_mask[i] for i in idx], [label[i] for i in idx] num_samples = self.batch_size # shuffle samples idx = np.random.permutation(num_samples) data, label_mask, label = [data[i] for i in idx], [label_mask[i] for i in idx], [label[i] for i in idx] # sample to a fixed length for i in range(num_samples): k = len(data[i]) idx = np.concatenate((np.tile(np.arange(k), (self.sample_size // k, )), np.random.permutation(k)[:(self.sample_size % k)]), axis=0) data[i] = data[i][idx, :] label[i] = label[i][idx] data = np.concatenate(data, axis=0) # (NxS) x C label = np.concatenate(label, axis=0) # (NxS) label_mask = np.concatenate([l.reshape(1, -1, 1, 1) for l in label_mask], axis=0) # N x 50 x 1 x 1 # data aug. 
if self.jitter_rotation > 0: rotations = (('x', (2 * np.random.rand() - 1) * self.jitter_rotation * np.pi / 180.0), ('y', (2 * np.random.rand() - 1) * self.jitter_rotation * np.pi / 180.0), ('z', (2 * np.random.rand() - 1) * self.jitter_rotation * np.pi / 180.0)) if 'x' in self.raw_dims: feat_idx = self.raw_dims.index('x') data[:, feat_idx:feat_idx + 3] = rotate_3d(data[:, feat_idx:feat_idx + 3], rotations) if 'nx' in self.raw_dims: feat_idx = self.raw_dims.index('nx') data[:, feat_idx:feat_idx + 3] = rotate_3d(data[:, feat_idx:feat_idx + 3], rotations) if self.jitter_stretch > 0: stretch_strength = (2 * np.random.rand(3) - 1) * self.jitter_stretch + 1 if 'x' in self.raw_dims: feat_idx = self.raw_dims.index('x') data[:, feat_idx:feat_idx + 3] *= stretch_strength if 'nx' in self.raw_dims: feat_idx = self.raw_dims.index('nx') data[:, feat_idx:feat_idx + 3] *= stretch_strength data[:, feat_idx:feat_idx + 3] /= np.sqrt( np.power(data[:, feat_idx:feat_idx + 3], 2).sum(axis=1, keepdims=True)) if self.jitter_xyz > 0 and 'x' in self.raw_dims: feat_idx = self.raw_dims.index('x') data[:, feat_idx:feat_idx + 3] += ((2 * np.random.rand(3) - 1) * self.jitter_xyz) # reshape and reset index self.data = data[:, self.feat_dims].reshape(num_samples, self.sample_size, -1, 1).transpose(0, 2, 3, 1) self.label = label.reshape(num_samples, self.sample_size, -1, 1).transpose(0, 2, 3, 1) self.label_mask = label_mask self.index = 0
def _data_processing_batch(self, data, label): data = np.concatenate(data, axis=0) # (NxS) x C label = np.concatenate(label, axis=0) # (NxS) # data aug. # TODO should this be done at batch level? if self.jitter_rotation > 0: rotations = (('x', (2 * np.random.rand() - 1) * self.jitter_rotation * np.pi / 180.0), ('y', (2 * np.random.rand() - 1) * self.jitter_rotation * np.pi / 180.0), ('z', (2 * np.random.rand() - 1) * self.jitter_rotation * np.pi / 180.0)) if 'x' in self.raw_dims: feat_idx = self.raw_dims.index('x') data[:, feat_idx:feat_idx + 3] = rotate_3d(data[:, feat_idx:feat_idx + 3], rotations) if 'r' in self.raw_dims: feat_idx = self.raw_dims.index('r') data[:, feat_idx:feat_idx + 3] = rotate_3d(data[:, feat_idx:feat_idx + 3], rotations) if self.jitter_stretch > 0: stretch_strength = (2 * np.random.rand(3) - 1) * self.jitter_stretch + 1 if 'x' in self.raw_dims: feat_idx = self.raw_dims.index('x') data[:, feat_idx:feat_idx + 3] *= stretch_strength if 'r' in self.raw_dims: feat_idx = self.raw_dims.index('r') data[:, feat_idx:feat_idx + 3] *= stretch_strength data[:, feat_idx:feat_idx + 3] /= np.sqrt( np.power(data[:, feat_idx:feat_idx + 3], 2).sum(axis=1, keepdims=True)) if self.jitter_xyz > 0 and 'x' in self.raw_dims: feat_idx = self.raw_dims.index('x') data[:, feat_idx:feat_idx + 3] += ((2 * np.random.rand(3) - 1) * self.jitter_xyz) return data, label
def forward(self, bottom, top):
    """Write the next batch into ``top[0]`` (features) and ``top[1]`` (labels).

    Reads ``self.batch_size * self.sample_size`` consecutive rows of
    ``self.data`` and ``self.label``, starting at ``self.idx``.  Each sample
    then gets its own jitter:

    * if ``self.jitter_color`` is not None, a random colour shift on the
      ``'r'`` 3-column group, clipped to [0, 255];
    * if ``self.jitter_h`` is not None, per-point Gaussian noise on ``'h'``;
    * if ``self.jitter_rotation`` is truthy, a rotation made of a tilt about
      z and about x, each uniform in [-pi/16, pi/16), and a full turn about y.
      Positions (``'x'`` group) rotate about (mean x, max y, mean z);
      normals (``'nx'`` group) rotate about the default centre.

    Columns are then selected and scaled per ``self.feat_scales`` (items
    ``v`` with ``v[0]`` = column index, ``v[1]`` = scale).  The outputs are
    written with shapes ``(batch_size, F, 1, sample_size)`` and
    ``(batch_size, L, 1, sample_size)``.  Finally ``self.idx`` advances, and
    ``self._restart()`` is called when the next batch would run past the end.
    ``bottom`` is not used.

    The stored ``self.data`` is never modified.  Previously the jitter was
    applied to a view of it and leaked back into the stored data.
    """
    points_per_batch = self.sample_size * self.batch_size
    start, stop = self.idx, self.idx + points_per_batch
    # Slicing + reshape yields a *view* of self.data; copy it so the in-place
    # jitter below cannot write back into the stored data.
    data = self.data[start:stop].reshape(self.batch_size, self.sample_size, -1).copy()
    label = self.label[start:stop]

    # Column offsets do not change across samples, so look them up once.
    color_idx = self.raw_dims.index('r') if self.jitter_color is not None else None
    h_idx = self.raw_dims.index('h') if self.jitter_h is not None else None
    xyz_idx = self.raw_dims.index('x') if 'x' in self.raw_dims else None
    normal_idx = self.raw_dims.index('nx') if 'nx' in self.raw_dims else None

    # Per-sample jittering.  TODO: improve performance (e.g. vectorise across
    # samples), keeping in mind that this changes the random-number order.
    for i in range(self.batch_size):
        if color_idx is not None:
            data[i, :, color_idx:color_idx + 3] += np.random.randn(3).dot(self.jitter_color.T)
            # clipping to [0, 255]
            data[i, :, color_idx:color_idx + 3] = np.maximum(
                0.0, np.minimum(255.0, data[i, :, color_idx:color_idx + 3]))
        if h_idx is not None:
            data[i, :, h_idx] += np.random.randn(self.sample_size) * self.jitter_h
        if self.jitter_rotation:
            rotations = (('z', np.random.rand() * np.pi / 8 - np.pi / 16),
                         ('x', np.random.rand() * np.pi / 8 - np.pi / 16),
                         ('y', np.random.rand() * np.pi * 2))
            if xyz_idx is not None:
                xyz = data[i, :, xyz_idx:xyz_idx + 3]
                center = (np.mean(xyz[:, 0]), np.max(xyz[:, 1]), np.mean(xyz[:, 2]))
                data[i, :, xyz_idx:xyz_idx + 3] = rotate_3d(xyz, rotations, center)
            if normal_idx is not None:
                data[i, :, normal_idx:normal_idx + 3] = rotate_3d(
                    data[i, :, normal_idx:normal_idx + 3], rotations)

    # Slicing and scaling of the output features.
    idxs = [v[0] for v in self.feat_scales]
    scs = [v[1] for v in self.feat_scales]
    data = data[:, :, idxs] * np.array(scs)
    top[0].data[...] = data.reshape(self.batch_size, self.sample_size, -1, 1).transpose(0, 2, 3, 1)
    top[1].data[...] = label.reshape(self.batch_size, self.sample_size, -1, 1).transpose(0, 2, 3, 1)

    # Advance; rebuild once the next batch would run past the end.
    self.idx += points_per_batch
    if self.idx + points_per_batch > len(self.data):
        # NOTE(review): the _restart defined alongside this method resets
        # self.index (not self.idx) and stores data as (N, F, 1, S); confirm
        # that this class's _restart resets self.idx and keeps a flat layout.
        self._restart()