def test_get_strides_convolution(self):
        """A ConvSpec built with strides=s should report strides of (s, s)."""
        for stride in (1, 2):
            conv = mobile_search_space_v3.ConvSpec(kernel_size=3,
                                                   strides=stride)
            self.assertEqual((stride, stride),
                             mobile_search_space_v3.get_strides(conv))

    def test_get_strides_residual_connection(self):
        """A ResidualSpec is expected to have stride (1, 1).

        Wrapping a ZeroSpec yields (1, 1); wrapping a GlobalAveragePoolSpec
        must raise a ValueError mentioning the residual stride requirement.
        """
        search_space = mobile_search_space_v3
        zero_residual = search_space.ResidualSpec(basic_specs.ZeroSpec())
        self.assertEqual((1, 1), search_space.get_strides(zero_residual))

        with self.assertRaisesRegex(ValueError,
                                    'Residual layer must have stride 1'):
            # The spec is built inside the context manager so that a
            # matching ValueError raised during construction also counts.
            pooled_residual = search_space.ResidualSpec(
                search_space.GlobalAveragePoolSpec())
            search_space.get_strides(pooled_residual)

    def test_get_strides_separable_convolution(self):
        """A SeparableConvSpec built with strides=s should report (s, s)."""
        for stride in (1, 2):
            sep_conv = mobile_search_space_v3.SeparableConvSpec(
                kernel_size=3, strides=stride, activation='relu')
            self.assertEqual((stride, stride),
                             mobile_search_space_v3.get_strides(sep_conv))

    def test_get_strides_depthwise_bottleneck(self):
        """A DepthwiseBottleneckSpec with strides=s should report (s, s)."""
        for stride in (1, 2):
            bottleneck = mobile_search_space_v3.DepthwiseBottleneckSpec(
                kernel_size=3,
                expansion_filters=72,
                use_squeeze_and_excite=False,
                strides=stride,
                activation='relu')
            self.assertEqual((stride, stride),
                             mobile_search_space_v3.get_strides(bottleneck))

    def test_get_strides_oneof(self):
        """A OneOf over ConvSpecs reports the stride its choices share.

        Choices whose strides disagree must raise a ValueError matching
        'Stride mismatch'.
        """
        def conv_choices(*kernel_stride_pairs):
            # One ConvSpec per (kernel_size, strides) pair, wrapped in a
            # OneOf tagged with basic_specs.OP_TAG.
            return schema.OneOf(
                [mobile_search_space_v3.ConvSpec(kernel_size=k, strides=s)
                 for k, s in kernel_stride_pairs],
                basic_specs.OP_TAG)

        for stride in (1, 2):
            self.assertEqual(
                (stride, stride),
                mobile_search_space_v3.get_strides(
                    conv_choices((3, stride), (5, stride))))

        with self.assertRaisesRegex(ValueError, 'Stride mismatch'):
            mobile_search_space_v3.get_strides(conv_choices((3, 1), (3, 2)))

 def test_get_strides_global_avg_pool(self):
     """A GlobalAveragePoolSpec should report strides of (None, None)."""
     pool = mobile_search_space_v3.GlobalAveragePoolSpec()
     self.assertEqual((None, None), mobile_search_space_v3.get_strides(pool))

 def test_get_strides_zero(self):
     """A ZeroSpec should report strides of (1, 1)."""
     zero = basic_specs.ZeroSpec()
     self.assertEqual((1, 1), mobile_search_space_v3.get_strides(zero))

 def test_get_strides_activation(self):
     """The RELU activation spec should report strides of (1, 1)."""
     relu_strides = mobile_search_space_v3.get_strides(
         mobile_search_space_v3.RELU)
     self.assertEqual((1, 1), relu_strides)