Example #1
0
    def test_independent_sample_hierarchical(self):
        """Checks entropy and log-prob of a hierarchically sampled nested OneOf.

        The outer OneOf chooses between an inner 3-way and an inner 4-way
        OneOf. With `hierarchical=True`, the expected entropy and sample
        log-probability combine the outer 2-way choice with the inner choice
        selected by `outer_mask`.
        """
        structure = schema.OneOf([
            schema.OneOf(['a', 'b', 'c'], basic_specs.OP_TAG),
            schema.OneOf(['d', 'e', 'f', 'g'], basic_specs.OP_TAG),
        ], basic_specs.OP_TAG)
        rl_structure, dist_info = controller.independent_sample(
            structure,
            increase_ops_probability=0,
            increase_filters_probability=0,
            hierarchical=True)

        tensors = {
            'outer_mask': rl_structure.mask,
            'entropy': dist_info['entropy'],
            'sample_log_prob': dist_info['sample_log_prob'],
        }

        self.evaluate(tf.global_variables_initializer())
        # Sampling is random, so draw several times to exercise different
        # outer choices.
        for _ in range(10):
            values = self.evaluate(tensors)
            if np.all(values['outer_mask'] == np.array([1, 0])):
                # Outer 2-way choice followed by the 3-way inner OneOf.
                self.assertAlmostEqual(values['entropy'],
                                       math.log(2) + math.log(3))
                self.assertAlmostEqual(values['sample_log_prob'],
                                       math.log(1 / 2) + math.log(1 / 3))
            elif np.all(values['outer_mask'] == np.array([0, 1])):
                # Outer 2-way choice followed by the 4-way inner OneOf.
                self.assertAlmostEqual(values['entropy'],
                                       math.log(2) + math.log(4))
                self.assertAlmostEqual(values['sample_log_prob'],
                                       math.log(1 / 2) + math.log(1 / 4))
            else:
                # fail() does not %-format extra positional arguments the way
                # logging does, so interpolate the mask into the message here.
                self.fail('Unexpected outer_mask: %s' % (values['outer_mask'],))
Example #2
0
    def bneck(s, skippable):
        """Construct a spec for an inverted bottleneck layer.

        If `collapse_shared_ops` is set, a single DepthwiseBottleneckSpec is
        produced whose kernel size and expansion filters are themselves OneOf
        nodes. Otherwise every (filter multiplier, kernel size) pair becomes its
        own choice, ordered multiplier-major. When `skippable` is true, a
        ZeroSpec is appended as an additional choice.

        Args:
          s: strides for the bottleneck.
          skippable: whether the layer may be replaced by a ZeroSpec.

        Returns:
          An OP-tagged schema.OneOf over the candidate layer specs.
        """
        filter_multipliers = [3.0, 6.0]
        kernel_sizes = [3, 5, 7]

        if not collapse_shared_ops:
            # One flat choice per (multiplier, kernel size) combination.
            choices = [
                DepthwiseBottleneckSpec(
                    kernel_size=k,
                    expansion_filters=basic_specs.FilterMultiplier(m),
                    use_squeeze_and_excite=False,
                    strides=s,
                    activation=RELU)
                for m in filter_multipliers
                for k in kernel_sizes
            ]
        else:
            # A single spec whose searchable fields are nested OneOfs.
            shared_kernel = schema.OneOf(kernel_sizes, basic_specs.OP_TAG)
            shared_filters = schema.OneOf(
                [basic_specs.FilterMultiplier(m) for m in filter_multipliers],
                basic_specs.FILTERS_TAG)
            choices = [
                DepthwiseBottleneckSpec(kernel_size=shared_kernel,
                                        expansion_filters=shared_filters,
                                        use_squeeze_and_excite=False,
                                        strides=s,
                                        activation=RELU)
            ]

        if skippable:
            choices.append(basic_specs.ZeroSpec())
        return schema.OneOf(choices, basic_specs.OP_TAG)
Example #3
0
def _build_model(params,
                 model_spec):
  """Translate a ConvTowerSpec namedtuple into a rematlib Layer.

  Layers are built block by block. A DetectionEndpointSpec does not create a
  layer; it records the most recently built layer as an auxiliary output.

  Args:
    params: passed through to `_build_layer` and `_make_head`.
    model_spec: a ConvTowerSpec whose `blocks` (each with `layers` and
      `filters`) and `filters_base` drive construction.

  Returns:
    A `layers.Sequential` of the built layers followed by the head, with the
    recorded endpoint layers as `aux_outputs`.

  Raises:
    ValueError: if a detection endpoint appears before any layer is built.
  """
  # The first layer consumes a fixed 3-channel input (presumably RGB images;
  # confirm against the input pipeline).
  prev_filters = schema.OneOf([3], basic_specs.FILTERS_TAG)
  last_layer = None
  built_layers = []
  aux_endpoints = []

  for block_spec in model_spec.blocks:
    for layer_spec in block_spec.layers:
      is_endpoint = isinstance(
          layer_spec, mobile_search_space_v3.DetectionEndpointSpec)
      if is_endpoint:
        if last_layer is None:
          raise ValueError(
              'The first layer of the network cannot be a detection endpoint.')
        aux_endpoints.append(last_layer)
        continue

      # Plain integer filter counts are wrapped in a single-choice OneOf.
      out_filters = block_spec.filters
      if isinstance(out_filters, int):
        out_filters = schema.OneOf([out_filters], basic_specs.FILTERS_TAG)

      last_layer = _build_layer(
          params, layer_spec, prev_filters, out_filters,
          model_spec.filters_base)
      built_layers.append(last_layer)
      prev_filters = out_filters

  built_layers.append(_make_head(params))
  return layers.Sequential(built_layers, aux_outputs=aux_endpoints)
    def test_prune_model_spec_with_path_dropout_rate_tensor(self):
        """Path dropout also works when the dropout rate is a Tensor."""
        op1_choices = [
            mobile_search_space_v3.ConvSpec(kernel_size=2, strides=2),
            basic_specs.ZeroSpec(),
        ]
        op2_choices = [
            mobile_search_space_v3.ConvSpec(kernel_size=3, strides=4),
        ]
        unpruned = {
            'op1': schema.OneOf(op1_choices, basic_specs.OP_TAG),
            'op2': schema.OneOf(op2_choices, basic_specs.OP_TAG),
            'filter': schema.OneOf([32], basic_specs.FILTERS_TAG),
        }
        # A dropout rate of 0.2, built as a graph computation rather than a
        # Python float.
        dropout_rate = tf.constant(2.0) / tf.constant(10.0)

        pruned = search_space_utils.prune_model_spec(
            unpruned, {basic_specs.OP_TAG: [0, 0]},
            path_dropout_rate=dropout_rate,
            training=True)

        self.assertCountEqual(pruned.keys(), ['op1', 'op2', 'filter'])
        # Only 'op1' is skippable (it contains a ZeroSpec), so only it gets a
        # dropout mask.
        self.assertEqual(pruned['op1'].mask.shape, tf.TensorShape([1]))
        self.assertIsNone(pruned['op2'].mask)
        self.assertIsNone(pruned['filter'].mask)

        # The value should either be 0 or 1 / (1 - path_dropout_rate) = 1.25
        mask_value = self.evaluate(pruned['op1'].mask)
        is_dropped = abs(mask_value - 0) < 1e-6
        is_kept = abs(mask_value - 1.25) < 1e-6
        self.assertTrue(
            is_dropped or is_kept,
            msg='Unexpected op_mask_value: {}'.format(mask_value))
 def test_tf_indices_with_masks(self):
     """tf_indices reports, per OneOf, the position of the 1 in its mask."""
     foo = schema.OneOf([1], 'foo', mask=tf.constant([1]))
     bar = schema.OneOf([2, 3], 'bar', mask=tf.constant([0, 1]))
     baz = schema.OneOf([4, 5, 6], 'baz', mask=tf.constant([0, 0, 1]))
     selected = search_space_utils.tf_indices([foo, bar, baz])
     self.assertAllEqual(self.evaluate(selected), [0, 1, 2])
    def test_get_mask(self):
        """get_mask rejects a OneOf without a mask and returns it otherwise."""
        with self.assertRaisesWithPredicateMatch(ValueError,
                                                 'OneOf must have a mask'):
            unmasked = schema.OneOf([6, 7, 8, 9], basic_specs.OP_TAG)
            cost_model_lib.get_mask(unmasked)

        masked = schema.OneOf([6, 7, 8, 9], basic_specs.OP_TAG,
                              tf.constant([0, 0, 1, 0]))
        mask_tensor = cost_model_lib.get_mask(masked)
        self.assertAllEqual(self.evaluate(mask_tensor), [0, 0, 1, 0])
    def test_v3_zero_or_conv_with_child(self):
        """Features for a ZeroSpec-or-ConvSpec layer with a searchable kernel.

        Judging by the expected values below, the four feature slots appear to
        be [zero, conv k=3, conv k=5, conv k=7]. When the ZeroSpec branch is
        selected, the kernel size mask has no effect.
        """
        kernel_size_mask = tf.placeholder(shape=[3], dtype=tf.float32)
        kernel_size = schema.OneOf([3, 5, 7], basic_specs.OP_TAG,
                                   kernel_size_mask)

        layer_mask = tf.placeholder(shape=[2], dtype=tf.float32)
        zero_or_conv = schema.OneOf(
            choices=[
                basic_specs.ZeroSpec(),
                mobile_search_space_v3.ConvSpec(kernel_size=kernel_size,
                                                strides=1),
            ],
            tag=basic_specs.OP_TAG,
            mask=layer_mask)

        features = mobile_cost_model.coupled_tf_features(
            _make_single_layer_model(zero_or_conv))

        # Each case is (layer mask, kernel size mask, expected features).
        cases = [
            ([1, 0], [1, 0, 0], [1.0, 0.0, 0.0, 0.0]),
            ([1, 0], [0, 1, 0], [1.0, 0.0, 0.0, 0.0]),
            ([1, 0], [0, 0, 1], [1.0, 0.0, 0.0, 0.0]),
            ([0, 1], [1, 0, 0], [0.0, 1.0, 0.0, 0.0]),
            ([0, 1], [0, 1, 0], [0.0, 0.0, 1.0, 0.0]),
            ([0, 1], [0, 0, 1], [0.0, 0.0, 0.0, 1.0]),
        ]
        with self.session() as sess:
            for layer_value, kernel_value, expected in cases:
                actual = sess.run(features, {
                    layer_mask: layer_value,
                    kernel_size_mask: kernel_value,
                })
                self.assertAllClose(expected, actual)
    def test_prune_simple_model_spec_validation_no_tags(self):
        """A list genotype shorter than the list of OneOfs is rejected."""
        two_oneofs = [
            schema.OneOf(['a', 'b', 'c'], basic_specs.OP_TAG),
            schema.OneOf(['z', 'w'], basic_specs.FILTERS_TAG),
        ]
        expected_message = (
            'Genotype contains 1 oneofs but model_spec contains 2')
        with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
            search_space_utils.prune_model_spec(model_spec=two_oneofs,
                                                genotype=[0],
                                                prune_filters_by_value=True)
    def test_prune_model_spec_with_path_dropout_training(self):
        """Path dropout during training masks only the skippable op.

        'op1' has a ZeroSpec among its choices, so it receives a random mask
        that is 0 (dropped) or 1 / (1 - 0.2) = 1.25 (kept). 'op2' and 'filter'
        get no mask. Choices and tags are pruned according to the genotype.
        """
        model_spec = {
            'op1':
            schema.OneOf([
                mobile_search_space_v3.ConvSpec(kernel_size=2, strides=2),
                basic_specs.ZeroSpec(),
            ], basic_specs.OP_TAG),
            'op2':
            schema.OneOf([
                mobile_search_space_v3.ConvSpec(kernel_size=3, strides=4),
            ], basic_specs.OP_TAG),
            'filter':
            schema.OneOf([32], basic_specs.FILTERS_TAG),
        }

        model_spec = search_space_utils.prune_model_spec(
            model_spec, {basic_specs.OP_TAG: [0, 0]},
            path_dropout_rate=0.2,
            training=True)

        self.assertCountEqual(model_spec.keys(), ['op1', 'op2', 'filter'])
        self.assertEqual(model_spec['op1'].mask.shape, tf.TensorShape([1]))
        self.assertIsNone(model_spec['op2'].mask)
        self.assertIsNone(model_spec['filter'].mask)

        self.assertEqual(
            model_spec['op1'].choices,
            [mobile_search_space_v3.ConvSpec(kernel_size=2, strides=2)])
        self.assertEqual(
            model_spec['op2'].choices,
            [mobile_search_space_v3.ConvSpec(kernel_size=3, strides=4)])
        self.assertEqual(model_spec['filter'].choices, [32])

        self.assertEqual(model_spec['op1'].tag, basic_specs.OP_TAG)
        self.assertEqual(model_spec['op2'].tag, basic_specs.OP_TAG)
        self.assertEqual(model_spec['filter'].tag, basic_specs.FILTERS_TAG)

        op_mask_sum = 0
        for _ in range(100):
            # The value should either be 0 or 1 / (1 - path_dropout_rate) = 1.25
            op_mask_value = self.evaluate(model_spec['op1'].mask)
            self.assertTrue(
                abs(op_mask_value - 0) < 1e-6
                or abs(op_mask_value - 1.25) < 1e-6,
                msg='Unexpected op_mask_value: {}'.format(op_mask_value))
            op_mask_sum += op_mask_value[0]

        # op_mask_sum == 1.25 * K, where K ~ Binomial(100, 0.8) counts the kept
        # draws, so the bounds below accept 60 <= K <= 90. The probability of
        # falling outside that range by chance is roughly 0.2% (p ~= 0.002),
        # almost entirely from K >= 91. Our random number generators are
        # deterministically seeded, so the test shouldn't be flaky.
        self.assertGreaterEqual(op_mask_sum, 75)
        self.assertLessEqual(op_mask_sum, 113)
 def test_prune_model_spec_prune_filters_by_value_with_invalid_value(self):
     """A FILTERS value that is not among the choices raises ValueError."""
     spec = {
         'op': schema.OneOf(['a', 'b', 'c'], basic_specs.OP_TAG),
         'filters': schema.OneOf([128, 256, 512], basic_specs.FILTERS_TAG),
     }
     # With prune_filters_by_value=True the FILTERS entry is a value rather
     # than an index, and 1024 is not one of the available values.
     genotype = {basic_specs.OP_TAG: [1], basic_specs.FILTERS_TAG: [1024]}
     with self.assertRaises(ValueError):
         search_space_utils.prune_model_spec(
             spec, genotype, prune_filters_by_value=True)
    def test_prune_simple_model_spec_no_tags(self):
        """A list genotype holds indices, applied to the OneOfs in order."""
        unpruned = [
            schema.OneOf(['a', 'b', 'c'], basic_specs.OP_TAG),
            schema.OneOf(['z', 'w'], basic_specs.FILTERS_TAG),
        ]
        actual = search_space_utils.prune_model_spec(
            model_spec=unpruned, genotype=[2, 1], prune_filters_by_value=True)

        expected = [
            schema.OneOf(['c'], basic_specs.OP_TAG),
            schema.OneOf(['w'], basic_specs.FILTERS_TAG),
        ]
        self.assertEqual(actual, expected)
    def test_map_oenofs_with_tuple_paths_trivial(self):
        """A bare OneOf is visited once, with the empty tuple as its path."""
        visited = []

        def scale_by_ten(path, oneof):
            visited.append((path, oneof))
            return schema.OneOf([c * 10 for c in oneof.choices], oneof.tag)

        result = schema.map_oneofs_with_tuple_paths(
            scale_by_ten, schema.OneOf([1, 2], 'tag'))
        self.assertEqual(result, schema.OneOf([10, 20], 'tag'))
        self.assertEqual([p for p, _ in visited], [()])
        self.assertEqual([o for _, o in visited],
                         [schema.OneOf([1, 2], 'tag')])
def _from_json(structure):
    """Convert a pure JSON data structure to one with namedtuples and OneOfs.

    None and values of `_PRIMITIVE_TYPES` are returned unchanged. A list is a
    tagged container: its first element is a type tag ('dict',
    'namedtuple:<name>', 'oneof', 'list' or 'tuple'), and the remaining
    elements form the payload, which is decoded recursively.

    Args:
      structure: the JSON-decoded value to convert.

    Returns:
      The reconstructed Python structure.

    Raises:
      ValueError: if `structure` has an unrecognized JSON type, is an empty
        list, carries an unsupported type tag, or encodes a OneOf whose fields
        are not exactly ('choices', 'tag').
    """
    if structure is None or isinstance(structure, _PRIMITIVE_TYPES):
        return structure
    if not isinstance(structure, list):
        raise ValueError('Unrecognized JSON type: {}'.format(type(structure)))

    # Validate explicitly instead of with `assert`. Asserts are stripped under
    # `python -O`, which would turn malformed input into obscure errors later.
    if not structure:
        raise ValueError('Expected a tagged JSON list, got an empty list')
    typename = structure[0]
    payload = structure[1:]

    if typename == 'dict':
        return {_from_json(k): _from_json(v) for (k, v) in payload}
    elif typename.startswith('namedtuple:'):
        cls = namedtuple_name_to_class(typename[len('namedtuple:'):])
        kv_pairs = [(_from_json(k), _from_json(v)) for (k, v) in payload]
        return _namedtuple_from_json(cls, kv_pairs)
    elif typename == 'oneof':
        keys = tuple(_from_json(k) for (k, _) in payload)
        if keys != ('choices', 'tag'):
            raise ValueError(
                'Serialized OneOf must have fields (choices, tag), '
                'got: {}'.format(keys))
        return schema.OneOf(*(_from_json(v) for (_, v) in payload))
    elif typename == 'list':
        return list(map(_from_json, payload))
    elif typename == 'tuple':
        return tuple(map(_from_json, payload))
    else:
        raise ValueError('Unsupported __type: {}'.format(typename))
Example #14
0
    def test_independent_sample_temperature(self):
        """Temperature keeps initial probabilities but scales the gradients."""
        structure = schema.OneOf(['foo', 'bar', 'baz'], basic_specs.OP_TAG)
        default_temperature = tf.constant(5.0, tf.float32)
        temperature = tf.placeholder_with_default(default_temperature,
                                                  shape=(),
                                                  name='temperature')
        rl_structure, dist_info = controller.independent_sample(
            structure, temperature=temperature)
        log_prob = dist_info['sample_log_prob']
        entropy = dist_info['entropy']

        with self.cached_session() as sess:
            sess.run(tf.global_variables_initializer())

            # A sampled mask must still be valid (one-hot) when the
            # temperature is not 1.
            self.assertOneHot(sess.run(rl_structure.mask))

            # The probabilities start out uniform, so before training the
            # temperature affects neither the log-prob nor the entropy.
            self.assertAlmostEqual(sess.run(log_prob), math.log(1 / 3))
            self.assertAlmostEqual(sess.run(entropy), math.log(3))

            # Gradients should be scaled by 1 / temperature. Only the selected
            # choice has a positive gradient, and the selection can change
            # between runs, so compare the maxima.
            variables = tf.trainable_variables()
            self.assertLen(variables, 1)

            grads = tf.gradients(log_prob, variables)[0]
            max_grad_at_t1 = np.max(sess.run(grads, {temperature: 1.0}))
            max_grad_at_t5 = np.max(sess.run(grads, {temperature: 5.0}))
            self.assertAlmostEqual(max_grad_at_t1 / 5, max_grad_at_t5)
 def test_v3_single_choice_zero_only(self):
     """A layer whose only choice is a ZeroSpec yields the features [1.0]."""
     only_zero = schema.OneOf(
         choices=[basic_specs.ZeroSpec()],
         tag=basic_specs.OP_TAG,
         mask=tf.constant([1.0]))
     model = _make_single_layer_model(only_zero)
     feature_tensor = mobile_cost_model.coupled_tf_features(model)
     self.assertAllClose(self.evaluate(feature_tensor), [1.0])
    def test_tf_argmax_or_zero(self):
        """tf_argmax_or_zero handles pruned, unpruned and masked OneOfs."""
        # Already pruned (single choice) with no mask: the index is 0.
        pruned = schema.OneOf([1], 'foo')
        self.assertEqual(
            self.evaluate(search_space_utils.tf_argmax_or_zero(pruned)), 0)

        # Several choices and no mask: rejected.
        with self.assertRaisesWithPredicateMatch(ValueError,
                                                 'Expect pruned structure'):
            search_space_utils.tf_argmax_or_zero(schema.OneOf([1, 2], 'foo'))

        # Several choices with a mask: the position of the mask's maximum.
        masked = schema.OneOf([1, 2], 'foo', tf.constant([0.0, 1.0]))
        self.assertEqual(
            self.evaluate(search_space_utils.tf_argmax_or_zero(masked)), 1)
Example #17
0
    def test_serialization_with_simple_structures(self):
        """Round-trips primitives, containers, namedtuples and OneOf nodes."""
        plain_values = [
            # Primitives.
            None, 1, 0.5, 1.0, 'foo',
            # Lists and tuples.
            [1, 2, 3], (1, 2, 3),
            # Dictionaries, including integer and tuple keys.
            {'a': 3, 'b': 4},
            {10: 'x', 20: 'y'},
            {(1, 2): 'x', (3, 4): 'y'},
        ]
        for value in plain_values:
            self._run_serialization_test(value)

        # Namedtuples are checked against their expected class.
        namedtuple_cases = [
            (NamedTuple1(42), NamedTuple1),
            (NamedTuple2(12345), NamedTuple2),
        ]
        for value, cls in namedtuple_cases:
            self._run_serialization_test(value, expected_type=cls)

        # OneOf nodes.
        self._run_serialization_test(schema.OneOf((1, 2, 3), 'tag'))
    def update_spec(oneof):
        """Prune a single OneOf from `model_spec` down to the genotype's choice.

        With a dict genotype, a OneOf whose tag has no genotype entry is
        returned unchanged. Otherwise the next selection for that tag is
        consumed. FILTERS selections are values rather than indices when
        `prune_filters_by_value` is set. With a list genotype, the next index
        is consumed regardless of tag.

        During training with a non-zero `path_dropout_rate`, a skippable OP
        OneOf (one containing `zero_spec`) gets a random mask of shape [1]. The
        mask is 0 with probability `path_dropout_rate` and 1 / keep_prob
        otherwise.
        """
        if genotype_is_dict:
            if oneof.tag not in genotype:
                return oneof
            # pop(0) consumes entries, mutating the caller's genotype lists.
            chosen = genotype[oneof.tag].pop(0)
            if oneof.tag == basic_specs.FILTERS_TAG and prune_filters_by_value:
                # Translate a filter value into its index among the choices.
                chosen = oneof.choices.index(chosen)
        else:
            chosen = genotype.pop(0)

        # Path dropout replaces the RL controller during stand-alone training.
        mask = None
        if (path_dropout_rate != 0.0 and training
                and oneof.tag == basic_specs.OP_TAG
                and zero_spec in oneof.choices):
            keep_prob = 1.0 - path_dropout_rate
            keep = tf.less(tf.random_uniform([1]), keep_prob)
            # Rescale so that each mask element has an expected value of 1.
            mask = tf.cast(keep, tf.float32) / keep_prob

        return schema.OneOf([oneof.choices[chosen]], oneof.tag, mask)
    def test_prune_model_spec_prune_filters_by_value(self):
        """With prune_filters_by_value, FILTERS genotype entries are values."""
        ops = schema.OneOf(['a', 'b', 'c'], basic_specs.OP_TAG)
        filters = schema.OneOf([128, 256, 512], basic_specs.FILTERS_TAG)
        genotype = {
            basic_specs.OP_TAG: [1],  # An index into ops.choices.
            basic_specs.FILTERS_TAG: [512],  # A value from filters.choices.
        }

        pruned_spec = search_space_utils.prune_model_spec(
            {'op': ops, 'filters': filters}, genotype,
            prune_filters_by_value=True)

        self.assertEqual(pruned_spec, {
            'op': schema.OneOf(['b'], basic_specs.OP_TAG),
            'filters': schema.OneOf([512], basic_specs.FILTERS_TAG),
        })
Example #20
0
  def test_independent_sample_increase_ops_probability_1(self):
    """increase_ops_probability=1 yields a uniform OP mask and log-prob 0."""
    op_structure = schema.OneOf(['foo', 'bar', 'baz'], basic_specs.OP_TAG)
    rl_structure, dist_info = controller.independent_sample(
        op_structure, increase_ops_probability=1.0)
    self.evaluate(tf.global_variables_initializer())

    uniform_mask = [1/3, 1/3, 1/3]
    self.assertAllClose(self.evaluate(rl_structure.mask), uniform_mask)
    log_prob_value = self.evaluate(dist_info['sample_log_prob'])
    self.assertEqual(log_prob_value, 0)
Example #21
0
 def sepconv(s):
     """Return an OP-tagged OneOf of separable convs (kernel 3/5/7), stride s."""
     return schema.OneOf(
         [SeparableConvSpec(kernel_size=k, strides=s, activation=RELU)
          for k in (3, 5, 7)],
         basic_specs.OP_TAG)
Example #22
0
 def update(oneof):
     """Give OP-tagged OneOfs a mask from `random_one_hot`; pass others through."""
     if oneof.tag != basic_specs.OP_TAG:
         return oneof
     return schema.OneOf(choices=oneof.choices,
                         tag=oneof.tag,
                         mask=random_one_hot(len(oneof.choices)))
Example #23
0
    def test_independent_sample_increase_ops_does_not_affect_ops(self):
        """increase_filters_probability leaves OP-tagged OneOfs unaffected.

        Note: despite the test name, this sets `increase_filters_probability`
        and checks that an OP OneOf is still sampled as a one-hot choice with
        log-prob log(1/2).
        """
        op_structure = schema.OneOf([42, 64], basic_specs.OP_TAG)
        rl_structure, dist_info = controller.independent_sample(
            op_structure, increase_filters_probability=1.0)
        self.evaluate(tf.global_variables_initializer())

        sampled_mask = self.evaluate(rl_structure.mask)
        self.assertOneHot(sampled_mask)
        log_prob_value = self.evaluate(dist_info['sample_log_prob'])
        self.assertAlmostEqual(log_prob_value, math.log(1 / 2))
Example #24
0
    def test_independent_sample_increase_filters_probability_0(self):
        """increase_filters_probability=0 leaves FILTERS sampling as is.

        The sampled mask is one-hot, and its log-prob is log(1/3) for three
        choices.
        """
        filters_oneof = schema.OneOf([4, 12, 8], basic_specs.FILTERS_TAG)
        rl_structure, dist_info = controller.independent_sample(
            filters_oneof, increase_filters_probability=0.0)
        self.evaluate(tf.global_variables_initializer())

        mask_value = self.evaluate(rl_structure.mask)
        self.assertOneHot(mask_value)
        log_prob_value = self.evaluate(dist_info['sample_log_prob'])
        self.assertAlmostEqual(log_prob_value, math.log(1 / 3))
Example #25
0
    def test_independent_sample_increase_filters_probability_1_big_space(self):
        """increase_filters_probability=1 picks the largest of 100 filters.

        With 100 choices, landing on the right one by random chance is
        unlikely, so this distinguishes deliberate selection from luck.
        """
        num_choices = 100
        filters_oneof = schema.OneOf(list(range(num_choices)),
                                     basic_specs.FILTERS_TAG)
        rl_structure, dist_info = controller.independent_sample(
            filters_oneof, increase_filters_probability=1.0)
        self.evaluate(tf.global_variables_initializer())

        expected_mask = [0] * (num_choices - 1) + [1]
        self.assertAllClose(self.evaluate(rl_structure.mask), expected_mask)
        self.assertEqual(self.evaluate(dist_info['sample_log_prob']), 0)
Example #26
0
    def test_independent_sample_increase_filters_probability_1(self):
        """increase_filters_probability=1 selects the max value, not the last.

        The choices are deliberately unsorted, so picking the maximum (12, at
        index 1) can be told apart from picking the final entry.
        """
        unsorted_filters = schema.OneOf([4, 12, 8], basic_specs.FILTERS_TAG)
        rl_structure, dist_info = controller.independent_sample(
            unsorted_filters, increase_filters_probability=1.0)
        self.evaluate(tf.global_variables_initializer())

        index_of_max = [0, 1, 0]
        self.assertAllClose(self.evaluate(rl_structure.mask), index_of_max)
        self.assertEqual(self.evaluate(dist_info['sample_log_prob']), 0)
Example #27
0
 def bneck(input_size, se, s, act):
     """Construct a DepthwiseBottleneckSpec namedtuple.

     Expansion filters are the sorted, de-duplicated candidates derived from
     `expansion_multipliers`. They are either FilterMultipliers (when
     `use_relative_expansion_filters` is set) or absolute counts computed by
     `search_space_utils.scale_filters(input_size, ..., base=8)`. When
     `search_squeeze_and_excite` is set, `se` is replaced by a OneOf over
     [False, True]. The kernel size is always searched over {3, 5, 7}.
     """
     unique_filters = set()
     for expansion in expansion_multipliers:
         if use_relative_expansion_filters:
             unique_filters.add(basic_specs.FilterMultiplier(expansion))
         else:
             unique_filters.add(
                 search_space_utils.scale_filters(input_size, expansion,
                                                  base=8))
     expansion_filters = sorted(unique_filters)

     if search_squeeze_and_excite:
         # Search over squeeze-and-excite instead of using the caller's value.
         se = schema.OneOf([False, True], basic_specs.OP_TAG)

     kernel_size = schema.OneOf([3, 5, 7], basic_specs.OP_TAG)
     return DepthwiseBottleneckSpec(
         kernel_size=kernel_size,
         expansion_filters=choose_filters(expansion_filters),
         use_squeeze_and_excite=se,
         strides=s,
         activation=act)
Example #28
0
  def update_mask(path, oneof):
    """Add a one-hot mask to 'oneof' whose value is derived from 'indices'.

    Consumes the next entry of `indices`, tracked by `State.position`, and
    returns a copy of `oneof` masked with `tf.one_hot` at that index.

    Args:
      path: location of `oneof` in the structure; used only in error messages.
      oneof: the schema.OneOf node to annotate.

    Returns:
      A schema.OneOf with the same choices and tag plus the one-hot mask.

    Raises:
      ValueError: if the index is outside [0, len(oneof.choices)).
    """
    index = indices[State.position]
    State.position += 1

    num_choices = len(oneof.choices)
    if index < 0 or index >= num_choices:
      # Use `{}` rather than `{:s}`/`{:d}` for caller-supplied values.
      # Format specs like `{:s}` raise TypeError for non-str paths (e.g.
      # tuple paths), which would mask the intended ValueError.
      raise ValueError(
          'Invalid index: {} for path: {} with {:d} choices'.format(
              index, path, num_choices))

    mask = tf.one_hot(index, num_choices)
    return schema.OneOf(oneof.choices, oneof.tag, mask)
    def test_map_oneofs_with_tuple_paths_containing_arrays_and_dicts(self):
        """Paths through dicts and lists are tuples of keys and list indices."""
        nested = {
            'foo': [
                schema.OneOf([1, 2], 'tag1'),
                schema.OneOf([3, 4, 5], 'tag2'),
            ]
        }
        visits = []

        def times_ten(path, oneof):
            visits.append((path, oneof))
            return schema.OneOf([x * 10 for x in oneof.choices], oneof.tag)

        mapped = schema.map_oneofs_with_tuple_paths(times_ten, nested)
        self.assertEqual(mapped, {
            'foo': [
                schema.OneOf([10, 20], 'tag1'),
                schema.OneOf([30, 40, 50], 'tag2'),
            ]
        })
        self.assertEqual([path for path, _ in visits],
                         [('foo', 0), ('foo', 1)])
        self.assertEqual([oneof for _, oneof in visits], [
            schema.OneOf([1, 2], 'tag1'),
            schema.OneOf([3, 4, 5], 'tag2'),
        ])
    def test_map_oneofs(self):
        """map_oneofs visits nested OneOfs in order and rebuilds the structure."""
        nested = {
            'foo': [
                schema.OneOf([1, 2], 'tag1'),
                schema.OneOf([3, 4, 5], 'tag2'),
            ]
        }
        visited = []

        def times_ten(oneof):
            visited.append(oneof)
            return schema.OneOf([c * 10 for c in oneof.choices], oneof.tag)

        mapped = schema.map_oneofs(times_ten, nested)
        self.assertEqual(mapped, {
            'foo': [
                schema.OneOf([10, 20], 'tag1'),
                schema.OneOf([30, 40, 50], 'tag2'),
            ]
        })
        self.assertEqual(visited, [
            schema.OneOf([1, 2], 'tag1'),
            schema.OneOf([3, 4, 5], 'tag2'),
        ])