def test_tf_scale_filters_types_and_shapes(self):
    """Check the shape and dtype of what `tf_scale_filters` returns.

    The function is called once with plain Python numbers and once with
    `tf.constant` arguments; in both cases the result must report an empty
    shape (`[]`) and the `tf.int32` dtype.
    """
    python_args = (32, 1 / 4, 8)
    tensor_args = (
        tf.constant(64, tf.int32),
        tf.constant(1 / 3, tf.float32),
        tf.constant(8, tf.int32),
    )
    for call_args in (python_args, tensor_args):
        scaled = search_space_utils.tf_scale_filters(*call_args)
        self.assertEqual(scaled.shape, [])
        self.assertEqual(scaled.dtype, tf.int32)
def _compute_filters_for_multiplier(
        multiplier, input_filters_or_mask, filters_base):
    """Convert a FilterMultiplier to an integer (or int Tensor) filter size.

    Args:
        multiplier: Object whose `scale` attribute is forwarded as the
            scaling factor.
        input_filters_or_mask: Either a Python int, which is used directly
            as the input filter count, or a `tf.Tensor`, which is cast to
            int32 and summed to obtain the count (presumably a per-filter
            mask -- TODO confirm against callers).
        filters_base: Forwarded unchanged as the last argument of the
            scaling helper.

    Returns:
        The result of `search_space_utils.scale_filters` for an int input,
        or of `search_space_utils.tf_scale_filters` for a Tensor input.

    Raises:
        ValueError: If `input_filters_or_mask` is neither an int nor a
            `tf.Tensor`.
    """
    # Plain Python count: use the non-TF scaling helper.
    if isinstance(input_filters_or_mask, int):
        return search_space_utils.scale_filters(
            input_filters_or_mask, multiplier.scale, filters_base)

    if not isinstance(input_filters_or_mask, tf.Tensor):
        raise ValueError(
            'Unsupported type for input_filters_or_mask: {}'.format(
                input_filters_or_mask))

    # Tensor input: reduce it to a single int32 count before scaling.
    mask_as_int = tf.cast(input_filters_or_mask, tf.int32)
    filter_count = tf.reduce_sum(mask_as_int)
    return search_space_utils.tf_scale_filters(
        filter_count, multiplier.scale, filters_base)
def test_tf_scale_filters_values(self):
    """Compare evaluated `tf_scale_filters` results with expected values.

    Each case lists the three positional arguments followed by the value
    the evaluated result must equal. Every expected value is a multiple of
    the third argument.
    """
    # (filters, scale, base, expected)
    cases = (
        (32, 1 / 4, 8, 8),
        (64, 1 / 3, 8, 24),
        (64, 1.4, 32, 96),
        (64, 3, 32, 192),
        (68, 1.0, 8, 72),
        (68, 1.2, 8, 80),
        (76, 1.0, 8, 80),
        (76, 1.2, 8, 88),
    )
    for filters, scale, base, expected in cases:
        scaled = search_space_utils.tf_scale_filters(filters, scale, base)
        self.assertEqual(self.evaluate(scaled), expected)