def random_rotation_matrix(batch_shape=(), angle_stddev=0.06, angle_clip=0.18) -> tf.Tensor:
    """Sample random 3D rotations built from a random axis and a small angle.

    This differs slightly from the PointNet++ version: the rotation is built
    with `from_axis_angle` rather than `from_euler_angles`.

    Args:
        batch_shape: leading shape of the sampled batch (any iterable of ints).
        angle_stddev: standard deviation of the normally distributed angle.
        angle_clip: if truthy, angles are clipped to [-angle_clip, angle_clip].

    Returns:
        `from_axis_angle` applied to unit axes of shape batch_shape + (3,) and
        angles of shape batch_shape + (1,). Presumably these are rotation
        matrices of shape batch_shape + (3, 3); confirm against
        `from_axis_angle`.
    """
    # NOTE(review): `from_axis_angle` must be available at module level, e.g.
    # from tensorflow_graphics.geometry.transformation.rotation_matrix_3d.
    leading = tuple(batch_shape)

    # Direction: an isotropic Gaussian sample, normalised onto the unit sphere.
    raw_axis = tfrng.normal(shape=leading + (3, ))
    unit_axis = raw_axis / tf.linalg.norm(raw_axis, axis=-1, keepdims=True)

    # Magnitude: a small Gaussian angle, optionally clipped to bound the rotation.
    theta = tfrng.normal(shape=leading + (1, ), stddev=angle_stddev)
    theta = tf.clip_by_value(theta, -angle_clip, angle_clip) if angle_clip else theta

    return from_axis_angle(unit_axis, theta)
def augment_image_example(image: tf.Tensor, label: tf.Tensor, sample_weight=None, noise_stddev=0):
    """Standardize an image and optionally add Gaussian noise.

    Args:
        image: image tensor; it is cast to float32 before standardization.
        label: label tensor, passed through unchanged.
        sample_weight: optional sample weight, passed through unchanged.
        noise_stddev: if > 0, zero-mean Gaussian noise with this standard
            deviation is added to the standardized image.

    Returns:
        `tf.keras.utils.pack_x_y_sample_weight(features, label, sample_weight)`,
        where `features` is the processed image.
    """
    features = tf.image.per_image_standardization(tf.cast(image, tf.float32))
    if noise_stddev > 0:
        noise = tfrng.normal(shape=tf.shape(features), stddev=noise_stddev)
        features = features + noise
    return tf.keras.utils.pack_x_y_sample_weight(features, label, sample_weight)
def jitter_positions(positions, stddev=0.01, clip=None):
    """Perturb each point independently with normally distributed noise.

    Args:
        positions: float tensor. When `clip` is given, the last axis is
            treated as the coordinate axis.
        stddev: standard deviation of the noise. If None or 0, `positions`
            is returned unchanged.
        clip: if not None, the L2 norm of each noise vector over the last
            axis is clipped to at most this value.

    Returns:
        `positions` plus the sampled noise, with the same shape as `positions`.
    """
    if stddev is None or stddev == 0:
        return positions
    noise = tfrng.normal(shape=tf.shape(positions), stddev=stddev)
    if clip is not None:
        noise = tf.clip_by_norm(noise, clip, axes=[-1])
    return positions + noise
def tfrng_map_func(x):
    """Apply `transform` to `x` with a randomly sampled scale and shift.

    The scale is a scalar drawn from Normal(mean=1.0, stddev=0.1). The shift
    is a scalar from `tfrng.uniform` with its default range (presumably
    [0, 1); confirm). The normal draw happens before the uniform draw.
    """
    return transform(
        x,
        tfrng.normal((), stddev=0.1, mean=1.0),
        tfrng.uniform(()),
    )
def random_rigid_transform_matrix(stddev=0.02, clip=None, dim=3) -> tf.Tensor:
    """Sample a random near-identity `dim` x `dim` matrix.

    The result is the identity plus element-wise Gaussian noise. Despite the
    function name, it is generally not orthogonal, so it is a small random
    linear perturbation rather than a strictly rigid transform.

    Args:
        stddev: standard deviation of each noise entry.
        clip: if truthy, each noise entry is clipped to [-clip, clip].
        dim: size of the square matrix.

    Returns:
        Tensor of shape (dim, dim).
    """
    perturbation = tfrng.normal(shape=(dim, dim), stddev=stddev)
    if clip:
        perturbation = tf.clip_by_value(perturbation, -clip, clip)  # pylint: disable=invalid-unary-operand-type
    identity = tf.eye(dim)
    return identity + perturbation