Example #1
0
 def batch_norm_inference(module):
     """Run `module.batch_norm_inference` on small random operands.

     Every operand is multiplied by 1e-3, keeping magnitudes small to
     increase numerical stability. The operands are drawn in the fixed order
     x, mean, variance, offset, scale; keep that order in case
     `tf_utils.uniform` draws from shared RNG state.
     """
     small = 1e-3
     channels = 16
     x = tf_utils.uniform((4, channels)) * small
     mean = tf_utils.uniform((channels,)) * small
     # low=0.0 presumably keeps the variance non-negative.
     variance = tf_utils.uniform((channels,), low=0.0) * small
     offset = tf_utils.uniform((channels,)) * small
     scale = tf_utils.uniform((channels,)) * small
     module.batch_norm_inference(x, mean, variance, offset, scale)
Example #2
0
class QuantizationModule(tf_test_utils.TestModule):
    """Test module exposing a single 8-bit fake-quantization function."""

    @tf_test_utils.tf_function_unit_test(
        input_signature=[tf.TensorSpec([32], tf.float32)],
        # Inputs are drawn from the same [-6, 6] bounds used as min/max below.
        input_generator=lambda *spec_args: tf_utils.uniform(
            *spec_args, low=-6, high=6))
    def fake_quant(self, x):
        """Fake-quantize `x` with min=-6, max=6 and 8 bits.

        Args:
          x: float32 tensor of shape [32], per the input signature.

        Returns:
          The output of `tf.quantization.fake_quant_with_min_max_args`.
        """
        quantized = tf.quantization.fake_quant_with_min_max_args(
            x, min=-6, max=6, num_bits=8, narrow_range=False, name=None)
        return quantized
Example #3
0
 def basic_matmul(module):
     """Call `module.basic_matmul` on random [4, 2] and [2, 4] operands."""
     lhs = tf_utils.uniform([4, 2])
     rhs = tf_utils.uniform([2, 4])
     module.basic_matmul(lhs, rhs)
Example #4
0
 def matmul_rhs_batch(module):
     """Call `module.matmul_rhs_batch` with a rank-2 lhs and a rank-3 rhs."""
     lhs = tf_utils.uniform([4, 2])
     batched_rhs = tf_utils.uniform([3, 2, 4])
     module.matmul_rhs_batch(lhs, batched_rhs)
Example #5
0
 def id_batch_size_1(module):
     """Feed a [1, 4, 5, 1] input and a [2, 2, 1, 1] kernel to the dilated conv.

     NOTE(review): despite the "id" name the kernel is random, not an
     identity kernel -- confirm this is intended.
     """
     image = tf_utils.uniform([1, 4, 5, 1])
     kernel = tf_utils.uniform([2, 2, 1, 1], dtype=np.float32)
     module.conv2d_1451x2211_dilated_valid(image, kernel)
Example #6
0
 def predict(module):
     """Call `module.predict` on one random input of shape get_input_shape()."""
     shape = get_input_shape()
     module.predict(tf_utils.uniform(shape))
Example #7
0
 def simple_matmul(module):
   """Multiply random [128, 3072] and [3072, 256] matrices scaled by 1e-3."""
   # Scale by a small value to increase numerical stability.
   factor = 1e-3
   lhs = tf_utils.uniform((128, 3072)) * factor
   rhs = tf_utils.uniform((3072, 256)) * factor
   module.simple_matmul(lhs, rhs)
Example #8
0
 def concat2axis(module):
   """Pass two random [1, 5, 1] tensors to `module.concat2axis`."""
   first = tf_utils.uniform([1, 5, 1])
   second = tf_utils.uniform([1, 5, 1])
   module.concat2axis(first, second)
 def matmul_dynamic_matching_batch(module):
   """Call `module.matmul_dynamic` with equal batch dims: [2, 2, 3], [2, 3, 4]."""
   lhs = tf_utils.uniform([2, 2, 3])
   rhs = tf_utils.uniform([2, 3, 4])
   module.matmul_dynamic(lhs, rhs)
 def matmul_dynamic_rank_broadcasting(module):
   """Pair a rank-3 [7, 2, 3] lhs with a rank-2 [3, 4] rhs."""
   batched_lhs = tf_utils.uniform([7, 2, 3])
   rhs = tf_utils.uniform([3, 4])
   module.matmul_dynamic_lhs_batch(batched_lhs, rhs)
Example #11
0
 def predict(module):
     """Call `module.predict` on a random input with atol=rtol=1e-5."""
     tolerance = 1e-5
     sample = tf_utils.uniform(get_input_shape())
     module.predict(sample, atol=tolerance, rtol=tolerance)
Example #12
0
 def matmul_broadcast_singleton_dimension(module):
     """Multiply a [1, 4, 2] lhs (size-1 leading dim) by a [3, 2, 4] rhs."""
     lhs = tf_utils.uniform([1, 4, 2])
     rhs = tf_utils.uniform([3, 2, 4])
     module.matmul_broadcast_singleton_dimension(lhs, rhs)
Example #13
0
 "expm1":
     tf_test_utils.unit_test_specs_from_signatures(
         signature_shapes=UNARY_SIGNATURE_SHAPES,
         signature_dtypes=[tf.float32, tf.complex64]),
 "floor":
     tf_test_utils.unit_test_specs_from_signatures(
         signature_shapes=UNARY_SIGNATURE_SHAPES,
         signature_dtypes=[tf.float32]),
 "floordiv":
     tf_test_utils.unit_test_specs_from_signatures(
         signature_shapes=BINARY_SIGNATURE_SHAPES,
         signature_dtypes=[tf.float32, tf.int32],
         # Avoid integer division by 0.
         input_generators={
             "uniform_1_3":
                 lambda *args: tf_utils.uniform(*args, low=1.0, high=3.0)
         }),
 "floormod":
     tf_test_utils.unit_test_specs_from_signatures(
         signature_shapes=BINARY_SIGNATURE_SHAPES,
         signature_dtypes=[tf.float32, tf.int32],
         # Avoid integer division by 0.
         input_generators={
             "uniform_1_3":
                 lambda *args: tf_utils.uniform(*args, low=1.0, high=3.0)
         }),
 "greater":
     tf_test_utils.unit_test_specs_from_signatures(
         signature_shapes=BINARY_SIGNATURE_SHAPES,
         signature_dtypes=[tf.float32, tf.int32]),
 "greater_equal":
Example #14
0
 def abs(module):
     """Call `module.fake_quant` on 32 values drawn with low=-6, high=6.

     NOTE(review): the name `abs` shadows the builtin and does not describe
     the method exercised; it is kept because callers may refer to it.
     """
     values = tf_utils.uniform([32], low=-6, high=6)
     module.fake_quant(values)
Example #15
0
 def asymmetric_kernel(module):
     """Convolve a random [1, 4, 5, 1] input with a fixed asymmetric kernel."""
     image = tf_utils.uniform([1, 4, 5, 1])
     # Row-major values of [[1, 4, 2], [-2, 0, 1]], reshaped to (2, 3, 1, 1).
     kernel = np.reshape(
         np.array([1, 4, 2, -2, 0, 1], dtype=np.float32), (2, 3, 1, 1))
     module.conv2d_1451x2311_valid(image, kernel)
Example #16
0
 def id_batch_size_2(module):
     """Feed a [2, 4, 5, 1] input and a [1, 1, 1, 1] kernel to the valid conv.

     NOTE(review): the kernel is random rather than an identity kernel,
     despite the "id" name -- confirm this is intended.
     """
     batch = tf_utils.uniform([2, 4, 5, 1])
     kernel = tf_utils.uniform([1, 1, 1, 1], dtype=np.float32)
     module.conv2d_2451x1111_valid(batch, kernel)
Example #17
0
 def id_batch_size_1(module):
     """Feed a [1, 4, 5, 2] input and a [2, 2, 2, 3] kernel to the dilated conv."""
     image = tf_utils.uniform([1, 4, 5, 2])
     kernel = tf_utils.uniform([2, 2, 2, 3], dtype=np.float32)
     module.conv2d_1452x2223_dilated_valid(image, kernel)
 def matmul_high_rank_batch(module):
   """Multiply rank-4 operands with leading dims [1, 7] and [7, 1]."""
   lhs = tf_utils.uniform([1, 7, 4, 2])
   rhs = tf_utils.uniform([7, 1, 2, 4])
   module.matmul_high_rank_batch(lhs, rhs)
Example #19
0
 def basic_matmul(module):
   """Multiply random [LEFT_DIM, INNER_DIM] and [INNER_DIM, RIGHT_DIM] inputs."""
   lhs_shape = [LEFT_DIM, INNER_DIM]
   rhs_shape = [INNER_DIM, RIGHT_DIM]
   module.basic_matmul(tf_utils.uniform(lhs_shape),
                       tf_utils.uniform(rhs_shape))
 def matmul_dynamic_broadcast_rhs(module):
   """Call `module.matmul_dynamic` with a [2, 2, 3] lhs and a [1, 3, 4] rhs."""
   lhs = tf_utils.uniform([2, 2, 3])
   rhs = tf_utils.uniform([1, 3, 4])
   module.matmul_dynamic(lhs, rhs)
Example #21
0
 def matmul_rhs_batch(module):
   """Multiply a rank-2 lhs by a rhs carrying a leading BATCH_DIM."""
   lhs = tf_utils.uniform([LEFT_DIM, INNER_DIM])
   rhs = tf_utils.uniform([BATCH_DIM, INNER_DIM, RIGHT_DIM])
   module.matmul_rhs_batch(lhs, rhs)
Example #22
0
 def dynamic_batch(module):
     """Call `module.predict` on a random [3, 784] batch scaled by 1e-3."""
     # 28 * 28 presumably corresponds to flattened 28x28 images -- confirm.
     batch_size, features = 3, 28 * 28
     x = tf_utils.uniform([batch_size, features]) * 1e-3
     module.predict(x)
Example #23
0
 def matmul_broadcast_singleton_dimension(module):
   """Multiply a lhs with a size-1 leading dim by a BATCH_DIM-batched rhs."""
   lhs = tf_utils.uniform([1, LEFT_DIM, INNER_DIM])
   rhs = tf_utils.uniform([BATCH_DIM, INNER_DIM, RIGHT_DIM])
   module.matmul_broadcast_singleton_dimension(lhs, rhs)
Example #24
0
 def concat_zero_dim(module):
   """Pass a zero-sized [1, 5, 0] tensor and a [1, 5, 1] tensor to concat."""
   empty = tf_utils.uniform([1, 5, 0])
   non_empty = tf_utils.uniform([1, 5, 1])
   module.concat_zero_dim(empty, non_empty)
Example #25
0
 def add_same_shape(module):
     """Call `module.add` on two random length-4 vectors."""
     lhs, rhs = tf_utils.uniform([4]), tf_utils.uniform([4])
     module.add(lhs, rhs)
Example #26
0
 def add_broadcast_rhs(module):
     """Call `module.add` on a length-4 lhs and a length-1 rhs."""
     operands = (tf_utils.uniform([4]), tf_utils.uniform([1]))
     module.add(*operands)
 def predict(module):
     """Call `module.predict` on one random input per configured shape.

     A single input is passed unwrapped; otherwise the whole list is passed.

     NOTE(review): `self` is not a parameter here; it presumably comes from
     an enclosing method's scope -- confirm against the surrounding code.
     """
     shapes = self._input_shapes
     generated = [tf_utils.uniform(s) for s in shapes]
     if len(generated) == 1:
         payload = generated[0]
     else:
         payload = generated
     module.predict(payload, atol=1e-5)