Ejemplo n.º 1
0
def _freeze_attr(val: Any) -> Any:
    """Recursively convert a value into an immutable counterpart.

    Dicts and ``FrozenDict``s become ``FrozenDict``s and lists become tuples,
    with their contents frozen recursively. Tuples are rebuilt with frozen
    contents; namedtuples and ``PartitionSpec`` instances keep their own type
    instead of being downgraded to plain tuples. Any other value is returned
    unchanged.

    Args:
        val: The value to freeze.

    Returns:
        The frozen value.
    """
    if isinstance(val, (dict, FrozenDict)):
        return FrozenDict({k: _freeze_attr(v) for k, v in val.items()})
    elif isinstance(val, tuple):
        # Namedtuples (detected via ``_fields``) and JAX's ``PartitionSpec``
        # are tuple subclasses whose type carries meaning; rebuilding them as
        # a plain ``tuple`` would silently drop that type.
        keep_type = (
            hasattr(val, '_fields') or type(val).__name__ == 'PartitionSpec')
        frozen_items = [_freeze_attr(v) for v in val]
        if keep_type:
            return type(val)(*frozen_items)
        return tuple(frozen_items)
    elif isinstance(val, list):
        return tuple(_freeze_attr(v) for v in val)
    else:
        return val
Ejemplo n.º 2
0
    def test_param_selection(self):
        """Check ModelParamTraversal on Models, plain dicts and FrozenDicts.

        For every container type, ``iterate`` must yield only the values
        whose path passes ``filter_fn`` and must visit every leaf path.
        ``update`` must apply the function only to the selected leaves.
        """
        params = {
            'x': {
                'kernel': 1,
                'bias': 2,
                'y': {
                    'kernel': 3,
                    'bias': 4,
                },
                'z': {},
            },
        }
        expected_params = {
            'x': {
                'kernel': 2,
                'bias': 2,
                'y': {
                    'kernel': 6,
                    'bias': 4,
                },
                'z': {}
            },
        }
        names = []

        def filter_fn(name, _):
            names.append(name)  # track names passed to filter_fn for testing
            return 'kernel' in name

        traversal = optim.ModelParamTraversal(filter_fn)

        configs = [
            (nn.Model(None, params), nn.Model(None, expected_params)),
            (params, expected_params),
            (FrozenDict(params), FrozenDict(expected_params)),
        ]
        for model, expected_model in configs:
            # Reset the tracked names and iterate *this* container, so that
            # each container type is checked on its own instead of reusing the
            # result computed once for the Model.
            names.clear()
            values = list(traversal.iterate(model))
            self.assertEqual(values, [1, 3])
            self.assertEqual(
                set(names),
                {'/x/kernel', '/x/bias', '/x/y/kernel', '/x/y/bias'})
            new_model = traversal.update(lambda x: x + x, model)
            self.assertEqual(new_model, expected_model)
Ejemplo n.º 3
0
def _freeze_attr(val: Any) -> Any:
  """Return an immutable version of ``val``, converting containers recursively.

  Mappings (``dict`` or ``FrozenDict``) are rebuilt as ``FrozenDict`` and lists
  become tuples. Tuples get their contents frozen; namedtuples and
  ``PartitionSpec`` objects are rebuilt with their own type so that they are
  not flattened into plain tuples. Anything else is passed through as-is.
  """
  if isinstance(val, (dict, FrozenDict)):
    frozen_map = {}
    for key, item in val.items():
      frozen_map[key] = _freeze_attr(item)
    return FrozenDict(frozen_map)

  if isinstance(val, tuple):
    # Namedtuples and special JAX tuple structures would otherwise be
    # downgraded to plain tuples, so rebuild them with their own type.
    preserve_type = (
        hasattr(val, '_fields') or type(val).__name__ == 'PartitionSpec')
    frozen_items = [_freeze_attr(item) for item in val]
    if preserve_type:
      return type(val)(*frozen_items)
    return tuple(frozen_items)

  if isinstance(val, list):
    return tuple(map(_freeze_attr, val))

  return val
Ejemplo n.º 4
0
def set_frozen_dict(frozen_dict: FrozenDict, key: str, value: Any) -> FrozenDict:
    """Return a new ``FrozenDict`` equal to ``frozen_dict`` with ``key`` set.

    ``frozen_dict`` is unfrozen into a mutable dict, ``key`` is added or
    overwritten with ``value``, and a new ``FrozenDict`` is built from the
    result. (Assumes ``unfreeze()`` returns a fresh mutable copy, so the input
    is not modified; confirm against the FrozenDict implementation in use.)

    Args:
        frozen_dict: The mapping to copy.
        key: The key to add or overwrite.
        value: The value to store under ``key``.

    Returns:
        A ``FrozenDict`` containing the updated entry.
    """
    unfrozen_dict = frozen_dict.unfreeze()
    unfrozen_dict[key] = value
    # Pass the mapping positionally: ``FrozenDict(**unfrozen_dict)`` raises
    # ``TypeError`` ("keywords must be strings") whenever any existing key is
    # not a string.
    return FrozenDict(unfrozen_dict)