def _restart(self):
        """Rebuild all epoch arrays from the pristine copies and reset the read cursor.

        Pipeline:
          1. copy ``self.data_copy``, ``self.label_copy`` and ``self.label_mask_copy``
             so augmentation never touches the originals;
          2. if there are fewer samples than ``self.batch_size``, repeat them
             (whole copies plus a random subset) so exactly one batch fits;
          3. shuffle the samples;
          4. resample every sample to exactly ``self.sample_size`` points
             (whole repeats plus a random subset of its points, i.e. a random
             subsample when the sample has more points than that);
          5. apply one global random rotation / stretch / translation, drawn
             once and shared by every point of the epoch.

        Sets:
            self.data       -- (N, len(self.feat_dims), 1, self.sample_size)
            self.label      -- (N, L, 1, self.sample_size); L = 1 for 1-D labels
            self.label_mask -- (N, M, 1, 1), M being the per-sample mask length
            self.index      -- 0
        """
        # 1. Copy the pristine per-sample arrays.
        data = [d.copy() for d in self.data_copy]
        label = [lbl.copy() for lbl in self.label_copy]
        label_mask = [m.copy() for m in self.label_mask_copy]

        # 2. Repeat samples if necessary so that one full batch can be formed.
        num_samples = len(data)
        if num_samples < self.batch_size:
            idx = np.concatenate((np.tile(np.arange(num_samples), (self.batch_size // num_samples, )),
                                  np.random.permutation(num_samples)[:(self.batch_size % num_samples)]), axis=0)
            data, label_mask, label = [data[i] for i in idx], [label_mask[i] for i in idx], [label[i] for i in idx]
            num_samples = self.batch_size

        # 3. Shuffle samples.
        idx = np.random.permutation(num_samples)
        data, label_mask, label = [data[i] for i in idx], [label_mask[i] for i in idx], [label[i] for i in idx]

        # 4. Resample every sample to exactly self.sample_size points.
        for i in range(num_samples):
            k = len(data[i])
            idx = np.concatenate((np.tile(np.arange(k), (self.sample_size // k, )),
                                  np.random.permutation(k)[:(self.sample_size % k)]), axis=0)
            data[i] = data[i][idx, :]
            label[i] = label[i][idx]

        data = np.concatenate(data, axis=0)     # (NxS) x C
        label = np.concatenate(label, axis=0)   # (NxS)
        label_mask = np.concatenate([m.reshape(1, -1, 1, 1) for m in label_mask], axis=0)     # N x M x 1 x 1

        # 5. Global data augmentation (one random draw for the whole epoch).
        if self.jitter_rotation > 0:
            # Three independent angles, each uniform in +/- jitter_rotation degrees.
            rotations = (('x', (2 * np.random.rand() - 1) * self.jitter_rotation * np.pi / 180.0),
                         ('y', (2 * np.random.rand() - 1) * self.jitter_rotation * np.pi / 180.0),
                         ('z', (2 * np.random.rand() - 1) * self.jitter_rotation * np.pi / 180.0))
            if 'x' in self.raw_dims:
                feat_idx = self.raw_dims.index('x')
                data[:, feat_idx:feat_idx + 3] = rotate_3d(data[:, feat_idx:feat_idx + 3], rotations)
            if 'nx' in self.raw_dims:
                # Assuming rotate_3d applies a pure rotation (orthogonal), normals
                # take exactly the same transform as the points.
                feat_idx = self.raw_dims.index('nx')
                data[:, feat_idx:feat_idx + 3] = rotate_3d(data[:, feat_idx:feat_idx + 3], rotations)
        if self.jitter_stretch > 0:
            # Per-axis scale factors, each uniform in 1 +/- jitter_stretch.
            stretch_strength = (2 * np.random.rand(3) - 1) * self.jitter_stretch + 1
            if 'x' in self.raw_dims:
                feat_idx = self.raw_dims.index('x')
                data[:, feat_idx:feat_idx + 3] *= stretch_strength
            if 'nx' in self.raw_dims:
                feat_idx = self.raw_dims.index('nx')
                normals = data[:, feat_idx:feat_idx + 3]  # view: edits write into data
                # Normals transform with the inverse-transpose of the point
                # transform; for a diagonal scale that is a division, not a
                # multiplication. Then re-normalise to unit length, guarding
                # against zero-length normals (which would otherwise become NaN).
                normals /= stretch_strength
                normals /= np.maximum(np.linalg.norm(normals, axis=1, keepdims=True), 1e-12)
        if self.jitter_xyz > 0 and 'x' in self.raw_dims:
            # One global translation, each component uniform in +/- jitter_xyz.
            feat_idx = self.raw_dims.index('x')
            data[:, feat_idx:feat_idx + 3] += ((2 * np.random.rand(3) - 1) * self.jitter_xyz)

        # Keep the selected feature columns, reshape to N x C x 1 x S and reset the cursor.
        self.data = data[:, self.feat_dims].reshape(num_samples, self.sample_size, -1, 1).transpose(0, 2, 3, 1)
        self.label = label.reshape(num_samples, self.sample_size, -1, 1).transpose(0, 2, 3, 1)
        self.label_mask = label_mask
        self.index = 0
Exemplo n.º 2
0
    def _data_processing_batch(self, data, label):
        """Concatenate a batch of per-sample arrays and apply global augmentation.

        Args:
            data: sequence of per-sample point arrays; concatenated along axis 0
                into an ``(N*S, C)`` array whose columns are named by ``self.raw_dims``.
            label: sequence of per-sample label arrays; concatenated along axis 0.

        Returns:
            ``(data, label)``: the concatenated arrays, with ``data`` augmented
            in place. One random rotation / stretch / translation is drawn per
            call and shared by every point in the batch.
        """
        data = np.concatenate(data, axis=0)     # (NxS) x C
        label = np.concatenate(label, axis=0)   # (NxS)

        # data aug. # TODO should this be done at batch level?
        # Only point coordinates ('x'...) and surface normals ('nx'...) transform
        # with the geometry. 'r' names the colour channels, which must never be
        # rotated, stretched or unit-normalised.
        if self.jitter_rotation > 0:
            # Three independent angles, each uniform in +/- jitter_rotation degrees.
            rotations = (('x', (2 * np.random.rand() - 1) * self.jitter_rotation * np.pi / 180.0),
                         ('y', (2 * np.random.rand() - 1) * self.jitter_rotation * np.pi / 180.0),
                         ('z', (2 * np.random.rand() - 1) * self.jitter_rotation * np.pi / 180.0))
            if 'x' in self.raw_dims:
                feat_idx = self.raw_dims.index('x')
                data[:, feat_idx:feat_idx + 3] = rotate_3d(data[:, feat_idx:feat_idx + 3], rotations)
            if 'nx' in self.raw_dims:
                feat_idx = self.raw_dims.index('nx')
                data[:, feat_idx:feat_idx + 3] = rotate_3d(data[:, feat_idx:feat_idx + 3], rotations)
        if self.jitter_stretch > 0:
            # Per-axis scale factors, each uniform in 1 +/- jitter_stretch.
            stretch_strength = (2 * np.random.rand(3) - 1) * self.jitter_stretch + 1
            if 'x' in self.raw_dims:
                feat_idx = self.raw_dims.index('x')
                data[:, feat_idx:feat_idx + 3] *= stretch_strength
            if 'nx' in self.raw_dims:
                feat_idx = self.raw_dims.index('nx')
                normals = data[:, feat_idx:feat_idx + 3]  # view: edits write into data
                # Normals transform with the inverse-transpose of the point
                # transform (a division for a diagonal scale), then are
                # re-normalised; the eps guard avoids NaN for zero-length normals.
                normals /= stretch_strength
                normals /= np.maximum(np.linalg.norm(normals, axis=1, keepdims=True), 1e-12)
        if self.jitter_xyz > 0 and 'x' in self.raw_dims:
            # One global translation, each component uniform in +/- jitter_xyz.
            feat_idx = self.raw_dims.index('x')
            data[:, feat_idx:feat_idx + 3] += ((2 * np.random.rand(3) - 1) * self.jitter_xyz)
        return data, label
    def forward(self, bottom, top):
        """Write the next batch of jittered point samples into ``top``.

        Reads ``batch_size * sample_size`` consecutive rows of ``self.data`` and
        ``self.label`` starting at ``self.idx``, applies per-sample jitter
        (colour, height, rotation), keeps and scales the columns listed in
        ``self.feat_scales`` as ``(column_index, scale)`` pairs, and fills:

            top[0].data -- (batch_size, len(feat_scales), 1, sample_size)
            top[1].data -- (batch_size, L, 1, sample_size)

        ``bottom`` is unused. Advances ``self.idx`` and calls ``self._restart()``
        once fewer than a full batch of points remains.
        NOTE(review): ``top`` looks like Caffe Python-layer blobs -- confirm.
        """
        points_per_batch = self.sample_size * self.batch_size
        start, stop = self.idx, self.idx + points_per_batch
        # Copy: basic slicing + reshape gives a *view* of self.data, so the
        # in-place jitter below would otherwise be written back into the stored
        # dataset and compound every time those points are served again.
        data = self.data[start:stop].reshape(self.batch_size, self.sample_size, -1).copy()
        label = self.label[start:stop]

        # Per-sample jittering.
        # TODO: improve performance (e.g. move some computation to be done at once for all samples?)
        for i in range(self.batch_size):
            if self.jitter_color is not None:
                # One random colour offset per sample (three standard-normal
                # draws mixed by jitter_color), then clip back into [0, 255].
                feat_idx = self.raw_dims.index('r')
                data[i, :, feat_idx:feat_idx + 3] += np.random.randn(3).dot(
                    self.jitter_color.T)
                data[i, :, feat_idx:feat_idx + 3] = np.maximum(
                    0.0, np.minimum(255.0, data[i, :, feat_idx:feat_idx + 3]))
            if self.jitter_h is not None:
                # Independent Gaussian noise on the 'h' column of every point.
                feat_idx = self.raw_dims.index('h')
                data[i, :, feat_idx] += np.random.randn(
                    self.sample_size) * self.jitter_h
            if self.jitter_rotation:
                # Small tilts about z and x (+/- pi/16) and a full random spin about y.
                rotations = (('z', np.random.rand() * np.pi / 8 - np.pi / 16),
                             ('x', np.random.rand() * np.pi / 8 - np.pi / 16),
                             ('y', np.random.rand() * np.pi * 2))
                if 'x' in self.raw_dims:
                    feat_idx = self.raw_dims.index('x')
                    # Pivot: mean x, max y, mean z of this sample (presumably
                    # rotate_3d rotates about this point -- confirm).
                    center = (np.mean(data[i, :, feat_idx]),
                              np.max(data[i, :, feat_idx + 1]),
                              np.mean(data[i, :, feat_idx + 2]))
                    data[i, :, feat_idx:feat_idx + 3] = rotate_3d(
                        data[i, :, feat_idx:feat_idx + 3], rotations, center)
                if 'nx' in self.raw_dims:
                    # Normals are directions: same rotation, no pivot.
                    feat_idx = self.raw_dims.index('nx')
                    data[i, :, feat_idx:feat_idx + 3] = rotate_3d(
                        data[i, :, feat_idx:feat_idx + 3], rotations)

        # Slicing and scaling: keep the requested columns, each multiplied by its scale.
        idxs = [v[0] for v in self.feat_scales]
        scs = [v[1] for v in self.feat_scales]
        data = data[:, :, idxs] * np.array(scs)

        top[0].data[...] = data.reshape(self.batch_size, self.sample_size, -1,
                                        1).transpose(0, 2, 3, 1)
        top[1].data[...] = label.reshape(self.batch_size, self.sample_size, -1,
                                         1).transpose(0, 2, 3, 1)

        # Advance the cursor; restart once a full batch no longer fits.
        self.idx += points_per_batch
        if self.idx + points_per_batch > len(self.data):
            self._restart()