def test_nested_fields(self):
    """Referencing an internal node links its leaves, not the node itself."""
    cfg = ml_collections.ConfigDict({'parent': {'x': 1}})
    cfg.child = config_schema_utils.make_reference(cfg, 'parent')

    # 'cfg.child.x' tracks 'cfg.parent.x'. However, 'cfg.child' as a whole is
    # NOT a reference to the 'cfg.parent' node.
    self.assertEqual(cfg.parent.x, 1)
    self.assertEqual(cfg.child.x, 1)

    cfg.parent.x = 2
    for observed in (cfg.parent.x, cfg.child.x):
        self.assertEqual(observed, 2)

    # Rebinding 'parent' to a brand-new ConfigDict replaces the Python object.
    # 'child' still points at the old node, so it keeps the old value.
    cfg.parent = ml_collections.ConfigDict({'x': 3})
    self.assertEqual(cfg.parent.x, 3)
    self.assertEqual(cfg.child.x, 2)

    # By contrast, 'update' writes new values into the existing leaf nodes,
    # so the parent/child link between internal nodes survives. This is the
    # recommended way to change many fields at once.
    cfg = ml_collections.ConfigDict({'parent': {'x': 1, 'y': 'hello'}})
    cfg.child = config_schema_utils.make_reference(cfg, 'parent')
    cfg.parent.update({'x': 3, 'y': 'goodbye'})
    expected_leaves = (('x', 3), ('y', 'goodbye'))
    for node_name in ('parent', 'child'):
        node = getattr(cfg, node_name)
        for leaf_name, want in expected_leaves:
            self.assertEqual(getattr(node, leaf_name), want)
def get_config(num_blocks, use_auto_acts):
    """Returns a ConfigDict for an ImageNet ResNet model (ResNet50/ResNet101).

    The ConfigDict is wired up so that changing a field at one level of the
    hierarchy changes the value of that field everywhere downstream in the
    hierarchy. For example, changing the top-level 'prec' parameter
    (e.g., config.prec=4) will cause the precision of all layers to change.
    Changing the precision of a specific layer type
    (e.g., config.residual_block.conv_1.weight_prec=4) will cause the weight
    prec of all conv_1 layers to change, overriding the value of the global
    config.prec value.

    See config_schema_test.test_schema_matches_expected to see the structure
    of the ConfigDict instance this will return.

    Args:
      num_blocks: Number of residual blocks in the architecture.
      use_auto_acts: Whether to use automatic clipping bounds for activations
        or fixed bounds. Unlike other properties of the configuration, which
        can be overridden directly in the ConfigDict instance, this affects
        the immutable schema of the ConfigDict and so has to be specified
        before the ConfigDict is created.

    Returns:
      A ConfigDict instance which parallels the hierarchy of TrainingHParams.
      The instance is locked before it is returned.
    """
    base_config = get_base_config(use_auto_acts=use_auto_acts)
    model_hparams = base_config.model_hparams

    config_schema_utils.set_default_reference(model_hparams, base_config,
                                              "dense_layer")
    # Presumably makes model_hparams.conv_init default to base_config.conv.
    # The exact semantics live in config_schema_utils.
    config_schema_utils.set_default_reference(
        model_hparams, base_config, "conv_init", parent_field="conv")

    # Every block is a reference to the single shared base_config.residual
    # node. Updating leaves of that node reaches all blocks. Rebinding the
    # node itself does not.
    # NOTE(review): the docstring example names `config.residual_block`, but
    # the shared node referenced here is `residual`. Confirm which field name
    # users should edit.
    model_hparams.residual_blocks = [
        config_schema_utils.make_reference(base_config, "residual")
        for _ in range(num_blocks)
    ]

    # Declare the remaining model-level fields with typed placeholders. This
    # must happen before 'act_function' is wired to its default below.
    model_hparams.update({
        # Controls the number of parameters in the model by multiplying the
        # number of conv filters in each layer by this number.
        "filter_multiplier": float_ph(),
        "act_function": str_ph(),
        "se_ratio": float_ph(),
        # Feature group in the second group conv layer.
        "init_group": int_ph(),
    })
    config_schema_utils.set_default_reference(
        model_hparams, base_config, "act_function", parent_field="act_function")

    # Freeze the schema: the field set is now fixed, while values can still be
    # overridden.
    base_config.lock()
    return base_config
def test_scalar_fields(self):
    """A scalar reference follows its parent until the child is overridden."""
    config = ml_collections.ConfigDict({'parent_field': 1})
    config.child_field = config_schema_utils.make_reference(
        config, 'parent_field')

    # 'child_field' is a reference to 'parent_field'. Changes to
    # 'parent_field' propagate to 'child_field'.
    self.assertEqual(config.parent_field, 1)
    self.assertEqual(config.child_field, 1)

    config.parent_field = 2
    self.assertEqual(config.parent_field, 2)
    self.assertEqual(config.child_field, 2)

    # But changes to 'child_field' do NOT propagate back up to
    # 'parent_field'.
    config.child_field = 3
    self.assertEqual(config.parent_field, 2)
    self.assertEqual(config.child_field, 3)

    config.parent_field = 4
    self.assertEqual(config.parent_field, 4)
    # The reference was broken when 'child_field' was overridden above, so
    # the new parent value no longer reaches the child.
    self.assertEqual(config.child_field, 3)