def test_tf_scale_filters_types_and_shapes(self):
        """Verify tf_scale_filters yields a scalar int32 tensor.

        Covered with plain Python scalar arguments and with explicitly
        typed tf.constant arguments.
        """

        def check_scalar_int32(result):
            # A scalar tensor has an empty shape; the dtype must be int32.
            self.assertEqual(result.shape, [])
            self.assertEqual(result.dtype, tf.int32)

        # Python int / float arguments.
        check_scalar_int32(search_space_utils.tf_scale_filters(32, 1 / 4, 8))

        # Tensor arguments with explicit dtypes.
        tensor_filters = tf.constant(64, tf.int32)
        tensor_scale = tf.constant(1 / 3, tf.float32)
        tensor_base = tf.constant(8, tf.int32)
        check_scalar_int32(search_space_utils.tf_scale_filters(
            tensor_filters, tensor_scale, tensor_base))
# Example No. 2
# 0
def _compute_filters_for_multiplier(
    multiplier,
    input_filters_or_mask,
    filters_base):
  """Convert a FilterMultiplier to an integer (or int Tensor) filter size.

  Args:
    multiplier: Object exposing a `scale` attribute; presumably a
      FilterMultiplier (its definition is not visible here -- verify against
      callers).
    input_filters_or_mask: Either a Python int giving the number of input
      filters, or a tf.Tensor whose elements are cast to int32 and summed to
      obtain the number of input filters (presumably a 0/1 or boolean mask).
    filters_base: Forwarded unchanged to the scaling helper.

  Returns:
    `search_space_utils.scale_filters(...)` when `input_filters_or_mask` is an
    int, or `search_space_utils.tf_scale_filters(...)` when it is a tf.Tensor.

  Raises:
    ValueError: If `input_filters_or_mask` is neither an int nor a tf.Tensor.
  """
  # NOTE(review): bool is a subclass of int in Python, so a bare True/False
  # also takes this branch and is treated as a filter count of 1/0.
  if isinstance(input_filters_or_mask, int):
    return search_space_utils.scale_filters(
        input_filters_or_mask, multiplier.scale, filters_base)
  elif isinstance(input_filters_or_mask, tf.Tensor):
    # Cast before summing so non-integer masks (e.g. bool) can be counted.
    input_filters = tf.reduce_sum(tf.cast(input_filters_or_mask, tf.int32))
    return search_space_utils.tf_scale_filters(
        input_filters, multiplier.scale, filters_base)
  else:
    # Name the offending type explicitly (the message promises a type), and
    # keep the value's repr for context.
    raise ValueError(
        'Unsupported type for input_filters_or_mask: {} (value: {!r})'.format(
            type(input_filters_or_mask).__name__, input_filters_or_mask))
    def test_tf_scale_filters_values(self):
        """Check concrete tf_scale_filters outputs for Python scalar inputs."""
        # Each row: (input_filters, scale, filters_base, expected_result).
        cases = (
            (32, 1 / 4, 8, 8),
            (64, 1 / 3, 8, 24),
            (64, 1.4, 32, 96),
            (64, 3, 32, 192),
            (68, 1.0, 8, 72),
            (68, 1.2, 8, 80),
            (76, 1.0, 8, 80),
            (76, 1.2, 8, 88),
        )
        for filters, scale, base, expected in cases:
            scaled = search_space_utils.tf_scale_filters(filters, scale, base)
            self.assertEqual(self.evaluate(scaled), expected)