def test_prune_model_spec_with_path_dropout_rate_tensor(self):
        """Checks prune_model_spec() when path_dropout_rate is a tf.Tensor.

        Only 'op1' is expected to receive a mask (of shape [1]); 'op2' and
        'filter' are expected to keep mask=None. The evaluated 'op1' mask must
        be either 0 or 1 / (1 - path_dropout_rate) = 1.25.
        """
        conv_or_zero = schema.OneOf([
            mobile_search_space_v3.ConvSpec(kernel_size=2, strides=2),
            basic_specs.ZeroSpec(),
        ], basic_specs.OP_TAG)
        conv_only = schema.OneOf([
            mobile_search_space_v3.ConvSpec(kernel_size=3, strides=4),
        ], basic_specs.OP_TAG)
        filters = schema.OneOf([32], basic_specs.FILTERS_TAG)
        unpruned = {'op1': conv_or_zero, 'op2': conv_only, 'filter': filters}

        # The rate is deliberately built from tensors rather than a Python
        # float; 2.0 / 10.0 gives a rate of 0.2.
        dropout_rate = tf.constant(2.0) / tf.constant(10.0)
        pruned = search_space_utils.prune_model_spec(
            unpruned, {basic_specs.OP_TAG: [0, 0]},
            path_dropout_rate=dropout_rate,
            training=True)

        self.assertCountEqual(pruned.keys(), ['op1', 'op2', 'filter'])
        self.assertEqual(pruned['op1'].mask.shape, tf.TensorShape([1]))
        for unmasked_key in ('op2', 'filter'):
            self.assertIsNone(pruned[unmasked_key].mask)

        # The value should either be 0 or 1 / (1 - path_dropout_rate) = 1.25
        sample = self.evaluate(pruned['op1'].mask)
        is_zero = abs(sample - 0) < 1e-6
        is_rescaled = abs(sample - 1.25) < 1e-6
        self.assertTrue(
            is_zero or is_rescaled,
            msg='Unexpected op_mask_value: {}'.format(sample))
    def test_get_strides_convolution(self):
        """get_strides() on a bare ConvSpec returns its stride as a 2-tuple."""
        for stride in (1, 2):
            conv_spec = mobile_search_space_v3.ConvSpec(kernel_size=3,
                                                        strides=stride)
            actual = mobile_search_space_v3.get_strides(conv_spec)
            self.assertEqual((stride, stride), actual)
    def test_prune_model_spec_with_path_dropout_training(self):
        """Checks prune_model_spec() with a float dropout rate, training=True.

        Verifies that each OneOf is reduced to its selected choice, that tags
        are preserved, that only 'op1' gets a mask (of shape [1]), and that
        100 evaluations of that mask each yield 0 or 1.25 with a sum inside
        the accepted [75, 113] window.
        """
        conv_2x2 = mobile_search_space_v3.ConvSpec(kernel_size=2, strides=2)
        conv_3x4 = mobile_search_space_v3.ConvSpec(kernel_size=3, strides=4)
        unpruned = {
            'op1': schema.OneOf([conv_2x2, basic_specs.ZeroSpec()],
                                basic_specs.OP_TAG),
            'op2': schema.OneOf([conv_3x4], basic_specs.OP_TAG),
            'filter': schema.OneOf([32], basic_specs.FILTERS_TAG),
        }

        pruned = search_space_utils.prune_model_spec(
            unpruned, {basic_specs.OP_TAG: [0, 0]},
            path_dropout_rate=0.2,
            training=True)

        self.assertCountEqual(pruned.keys(), ['op1', 'op2', 'filter'])
        self.assertEqual(pruned['op1'].mask.shape, tf.TensorShape([1]))
        for unmasked_key in ('op2', 'filter'):
            self.assertIsNone(pruned[unmasked_key].mask)

        expected_choices = (
            ('op1',
             [mobile_search_space_v3.ConvSpec(kernel_size=2, strides=2)]),
            ('op2',
             [mobile_search_space_v3.ConvSpec(kernel_size=3, strides=4)]),
            ('filter', [32]),
        )
        for key, choices in expected_choices:
            self.assertEqual(pruned[key].choices, choices)

        expected_tags = (
            ('op1', basic_specs.OP_TAG),
            ('op2', basic_specs.OP_TAG),
            ('filter', basic_specs.FILTERS_TAG),
        )
        for key, tag in expected_tags:
            self.assertEqual(pruned[key].tag, tag)

        mask_total = 0
        for _ in range(100):
            # The value should either be 0 or 1 / (1 - path_dropout_rate) = 1.25
            sample = self.evaluate(pruned['op1'].mask)
            is_zero = abs(sample - 0) < 1e-6
            is_rescaled = abs(sample - 1.25) < 1e-6
            self.assertTrue(
                is_zero or is_rescaled,
                msg='Unexpected op_mask_value: {}'.format(sample))
            mask_total += sample[0]

        # The probability of this test failing by random chance is roughly 0.002%.
        # Our random number generators are deterministically seeded, so the test
        # shouldn't be flakey.
        self.assertGreaterEqual(mask_total, 75)
        self.assertLessEqual(mask_total, 113)
    def test_v3_zero_or_conv_with_child(self):
        """Coupled features for a layer choosing ZeroSpec or a ConvSpec.

        The ConvSpec's kernel size is itself a OneOf over [3, 5, 7]. The
        expected feature vector has four entries: index 0 is set when ZeroSpec
        is selected (whatever the kernel-size mask says), and indices 1-3 are
        set when the conv is selected with kernel size 3, 5 or 7 respectively.
        """
        kernel_size_mask = tf.placeholder(shape=[3], dtype=tf.float32)
        kernel_size = schema.OneOf([3, 5, 7], basic_specs.OP_TAG,
                                   kernel_size_mask)

        layer_mask = tf.placeholder(shape=[2], dtype=tf.float32)
        layer_choices = [
            basic_specs.ZeroSpec(),
            mobile_search_space_v3.ConvSpec(kernel_size=kernel_size,
                                            strides=1),
        ]
        layer = schema.OneOf(choices=layer_choices,
                             tag=basic_specs.OP_TAG,
                             mask=layer_mask)

        model = _make_single_layer_model(layer)
        features = mobile_cost_model.coupled_tf_features(model)

        # Each case: (layer_mask feed, kernel_size_mask feed, expected output).
        cases = (
            ([1, 0], [1, 0, 0], [1.0, 0.0, 0.0, 0.0]),
            ([1, 0], [0, 1, 0], [1.0, 0.0, 0.0, 0.0]),
            ([1, 0], [0, 0, 1], [1.0, 0.0, 0.0, 0.0]),
            ([0, 1], [1, 0, 0], [0.0, 1.0, 0.0, 0.0]),
            ([0, 1], [0, 1, 0], [0.0, 0.0, 1.0, 0.0]),
            ([0, 1], [0, 0, 1], [0.0, 0.0, 0.0, 1.0]),
        )

        with self.session() as sess:
            for layer_feed, kernel_feed, expected in cases:
                feeds = {
                    layer_mask: layer_feed,
                    kernel_size_mask: kernel_feed,
                }
                self.assertAllClose(expected, sess.run(features, feeds))
    def test_prune_model_spec_with_path_dropout_eval(self):
        """Checks prune_model_spec() with a dropout rate but training=False.

        No OneOf in the result may carry a mask, while tags and the selected
        choices must still match what is expected.
        """
        conv_2x2 = mobile_search_space_v3.ConvSpec(kernel_size=2, strides=2)
        conv_3x4 = mobile_search_space_v3.ConvSpec(kernel_size=3, strides=4)
        unpruned = {
            'op1': schema.OneOf([conv_2x2, basic_specs.ZeroSpec()],
                                basic_specs.OP_TAG),
            'op2': schema.OneOf([conv_3x4], basic_specs.OP_TAG),
            'filter': schema.OneOf([32], basic_specs.FILTERS_TAG),
        }
        pruned = search_space_utils.prune_model_spec(
            unpruned, {basic_specs.OP_TAG: [0, 0]},
            path_dropout_rate=0.2,
            training=False)

        self.assertCountEqual(pruned.keys(), ['op1', 'op2', 'filter'])
        # Even though path_dropout_rate=0.2, the controller should not populate
        # the mask for op1 because we called prune_model_spec() with training=False.
        # In other words, path_dropout_rate should only affect the behavior during
        # training, not during evaluation.
        for key in ('op1', 'op2', 'filter'):
            self.assertIsNone(pruned[key].mask)

        expected_tags = (
            ('op1', basic_specs.OP_TAG),
            ('op2', basic_specs.OP_TAG),
            ('filter', basic_specs.FILTERS_TAG),
        )
        for key, tag in expected_tags:
            self.assertEqual(pruned[key].tag, tag)

        expected_choices = (
            ('op1',
             [mobile_search_space_v3.ConvSpec(kernel_size=2, strides=2)]),
            ('op2',
             [mobile_search_space_v3.ConvSpec(kernel_size=3, strides=4)]),
            ('filter', [32]),
        )
        for key, choices in expected_choices:
            self.assertEqual(pruned[key].choices, choices)
    def test_get_strides_oneof(self):
        """get_strides() on a OneOf whose choices are all ConvSpecs.

        When every choice uses the same stride s the expected result is
        (s, s); when the choices disagree, a ValueError matching
        'Stride mismatch' is expected.
        """

        def one_of_convs(kernel_stride_pairs):
            # Build an OP_TAG OneOf holding one ConvSpec per
            # (kernel_size, strides) pair, in the given order.
            return schema.OneOf([
                mobile_search_space_v3.ConvSpec(kernel_size=kernel,
                                                strides=stride)
                for kernel, stride in kernel_stride_pairs
            ], basic_specs.OP_TAG)

        stride_one = one_of_convs([(3, 1), (5, 1)])
        self.assertEqual((1, 1),
                         mobile_search_space_v3.get_strides(stride_one))

        stride_two = one_of_convs([(3, 2), (5, 2)])
        self.assertEqual((2, 2),
                         mobile_search_space_v3.get_strides(stride_two))

        with self.assertRaisesRegex(ValueError, 'Stride mismatch'):
            mobile_search_space_v3.get_strides(
                one_of_convs([(3, 1), (3, 2)]))
    def test_v3_two_children(self):
        """Coupled features for a layer choosing between two parametrized ops.

        The first choice is a SeparableConvSpec whose kernel size is a OneOf
        over [3, 5, 7]; the second is a ConvSpec whose kernel size is a OneOf
        over [3, 5]. The expected feature vector has five entries: indices 0-2
        for the first op's kernel sizes and indices 3-4 for the second op's.
        The kernel-size mask of the unselected op must not affect the result.
        """
        kernel_size1_mask = tf.placeholder(shape=[3], dtype=tf.float32)
        kernel_size1 = schema.OneOf([3, 5, 7], basic_specs.OP_TAG,
                                    kernel_size1_mask)

        kernel_size2_mask = tf.placeholder(shape=[2], dtype=tf.float32)
        kernel_size2 = schema.OneOf([3, 5], basic_specs.OP_TAG,
                                    kernel_size2_mask)

        layer_mask = tf.placeholder(shape=[2], dtype=tf.float32)
        layer_choices = [
            mobile_search_space_v3.SeparableConvSpec(kernel_size=kernel_size1,
                                                     strides=1),
            mobile_search_space_v3.ConvSpec(kernel_size=kernel_size2,
                                            strides=1),
        ]
        layer = schema.OneOf(choices=layer_choices,
                             tag=basic_specs.OP_TAG,
                             mask=layer_mask)

        model = _make_single_layer_model(layer)
        features = mobile_cost_model.coupled_tf_features(model)

        # Each case: (layer_mask, kernel_size1_mask, kernel_size2_mask,
        # expected features).
        cases = (
            # Select the first op; kernel_size2_mask should be ignored.
            ([1, 0], [1, 0, 0], [1, 0], [1.0, 0.0, 0.0, 0.0, 0.0]),
            ([1, 0], [0, 1, 0], [1, 0], [0.0, 1.0, 0.0, 0.0, 0.0]),
            ([1, 0], [0, 0, 1], [1, 0], [0.0, 0.0, 1.0, 0.0, 0.0]),
            # Select the second op; kernel_size1_mask should be ignored.
            ([0, 1], [1, 0, 0], [1, 0], [0.0, 0.0, 0.0, 1.0, 0.0]),
            ([0, 1], [1, 0, 0], [0, 1], [0.0, 0.0, 0.0, 0.0, 1.0]),
        )

        with self.session() as sess:
            for layer_feed, kernel1_feed, kernel2_feed, expected in cases:
                feeds = {
                    layer_mask: layer_feed,
                    kernel_size1_mask: kernel1_feed,
                    kernel_size2_mask: kernel2_feed,
                }
                self.assertAllClose(expected, sess.run(features, feeds))