def test_get_strides_convolution(self):
  """Checks that get_strides mirrors a ConvSpec's scalar stride on both axes."""
  for stride in (1, 2):
    conv = mobile_search_space_v3.ConvSpec(kernel_size=3, strides=stride)
    self.assertEqual((stride, stride), mobile_search_space_v3.get_strides(conv))
def test_get_strides_residual_connection(self):
  """Checks get_strides on ResidualSpec wrappers.

  A residual wrapped around a ZeroSpec reports stride (1, 1), while a
  residual wrapped around a GlobalAveragePoolSpec is rejected with a
  ValueError mentioning 'Residual layer must have stride 1'.
  """
  zero_residual = mobile_search_space_v3.ResidualSpec(basic_specs.ZeroSpec())
  self.assertEqual((1, 1), mobile_search_space_v3.get_strides(zero_residual))

  # The spec is built inside the context manager so that an error raised
  # during construction is also covered, exactly as in the original test.
  with self.assertRaisesRegex(ValueError, 'Residual layer must have stride 1'):
    pooled = mobile_search_space_v3.GlobalAveragePoolSpec()
    mobile_search_space_v3.get_strides(
        mobile_search_space_v3.ResidualSpec(pooled))
def test_get_strides_separable_convolution(self):
  """Checks that get_strides mirrors a SeparableConvSpec's stride on both axes."""
  for stride in (1, 2):
    sep_conv = mobile_search_space_v3.SeparableConvSpec(
        kernel_size=3, strides=stride, activation='relu')
    self.assertEqual(
        (stride, stride), mobile_search_space_v3.get_strides(sep_conv))
def test_get_strides_depthwise_bottleneck(self):
  """Checks that get_strides mirrors a DepthwiseBottleneckSpec's stride."""
  # Every field except `strides` is held fixed across the cases below.
  common_kwargs = dict(
      kernel_size=3,
      expansion_filters=72,
      use_squeeze_and_excite=False,
      activation='relu')
  for stride in (1, 2):
    bottleneck = mobile_search_space_v3.DepthwiseBottleneckSpec(
        strides=stride, **common_kwargs)
    self.assertEqual(
        (stride, stride), mobile_search_space_v3.get_strides(bottleneck))
def test_get_strides_oneof(self):
  """Checks get_strides on schema.OneOf choices over ConvSpecs.

  When every choice shares the same stride, that stride is reported on both
  axes. When the choices disagree, a ValueError mentioning 'Stride mismatch'
  is expected.
  """
  for stride in (1, 2):
    choices = schema.OneOf([
        mobile_search_space_v3.ConvSpec(kernel_size=3, strides=stride),
        mobile_search_space_v3.ConvSpec(kernel_size=5, strides=stride),
    ], basic_specs.OP_TAG)
    self.assertEqual(
        (stride, stride), mobile_search_space_v3.get_strides(choices))

  # Constructed inside the context manager, matching the original test.
  with self.assertRaisesRegex(ValueError, 'Stride mismatch'):
    mixed_choices = schema.OneOf([
        mobile_search_space_v3.ConvSpec(kernel_size=3, strides=1),
        mobile_search_space_v3.ConvSpec(kernel_size=3, strides=2),
    ], basic_specs.OP_TAG)
    mobile_search_space_v3.get_strides(mixed_choices)
def test_get_strides_global_avg_pool(self):
  """Checks that get_strides reports (None, None) for global average pooling."""
  pool = mobile_search_space_v3.GlobalAveragePoolSpec()
  self.assertEqual((None, None), mobile_search_space_v3.get_strides(pool))
def test_get_strides_zero(self):
  """Checks that get_strides reports (1, 1) for a ZeroSpec."""
  zero_op = basic_specs.ZeroSpec()
  self.assertEqual((1, 1), mobile_search_space_v3.get_strides(zero_op))
def test_get_strides_activation(self):
  """Checks that get_strides reports (1, 1) for the RELU activation spec."""
  activation = mobile_search_space_v3.RELU
  self.assertEqual((1, 1), mobile_search_space_v3.get_strides(activation))