def test_frozen_dict_reduce(self):
    before = FrozenDict(a=FrozenDict(b=1, c=2))
    cl, data = before.__reduce__()
    after = cl(*data)
    self.assertIsNot(before, after)
    self.assertEqual(before, after)
    self.assertEqual(after, {'a': {'b': 1, 'c': 2}})

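# A minimal sketch of what __reduce__ buys you: pickling, which uses
# __reduce__ under the hood, round-trips a FrozenDict. (Standalone
# example, not part of the test suite; assumes flax.core.FrozenDict.)
import pickle

from flax.core import FrozenDict

frozen = FrozenDict(a=FrozenDict(b=1, c=2))
restored = pickle.loads(pickle.dumps(frozen))
assert restored == frozen and restored is not frozen
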
def test_frozen_dict_repr(self):
    expected = (
"""FrozenDict({
    a: 1,
    b: {
        c: 2,
        d: {},
    },
})""")
    xs = FrozenDict({'a': 1, 'b': {'c': 2, 'd': {}}})
    self.assertEqual(repr(xs), expected)
    self.assertEqual(repr(FrozenDict()), 'FrozenDict({})')

def _resnet(rng, arch, block, layers, pretrained, **kwargs):
    resnet = ResNet(block=block, layers=layers, **kwargs)
    if pretrained:
        torch_params = utils.load_torch_params(model_urls[arch])
        flax_params = FrozenDict(
            utils.torch_to_linen(torch_params, _get_flax_keys))
    else:
        init_batch = jnp.ones((1, 224, 224, 3), jnp.float32)
        flax_params = resnet.init(rng, init_batch)
    return resnet, flax_params

def inception_v3(rng, pretrained=True, **kwargs):
    inception = Inception(**kwargs)
    if pretrained:
        torch_params = utils.load_torch_params(model_urls['inception_v3'])
        flax_params = FrozenDict(
            utils.torch_to_linen(torch_params, _get_flax_keys))
    else:
        init_batch = jnp.ones((1, 299, 299, 3), jnp.float32)
        flax_params = inception.init(rng, init_batch)
    return inception, flax_params

def _vgg(rng, arch, cfg, batch_norm, pretrained, **kwargs):
    vgg = VGG(cfg=cfgs[cfg], batch_norm=batch_norm, **kwargs)
    if pretrained:
        torch_params = utils.load_torch_params(model_urls[arch])
        flax_params = FrozenDict(
            _torch_to_vgg(torch_params, cfgs[cfg], batch_norm))
    else:
        init_batch = jnp.ones((1, 224, 224, 3), jnp.float32)
        flax_params = vgg.init(rng, init_batch)
    return vgg, flax_params

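# Hypothetical usage of the loaders above, mirroring the torchvision
# convention of thin per-architecture wrappers. `Bottleneck` and the
# [3, 4, 6, 3] layer spec are the standard ResNet-50 configuration;
# the wrapper name `resnet50` is an assumption, not taken from this file.
import jax

def resnet50(rng, pretrained=True, **kwargs):
    return _resnet(rng, 'resnet50', Bottleneck, [3, 4, 6, 3],
                   pretrained, **kwargs)

model, params = resnet50(jax.random.PRNGKey(0), pretrained=False)
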
def apply(
    self,
    params: tp.Any,
    states: tp.Any,
    training: bool,
    rng: types.RNGSeq,
) -> tp.Callable[..., types.OutputStates]:
    if params is None:
        params = FrozenDict()
    if states is None:
        states = FrozenDict()

    def _lambda(*args, **kwargs):
        def apply_fn(*args, **kwargs):
            variables = dict(params=params, **states)
            return self.module.apply(
                variables,
                *args,
                rngs={"params": rng.next()},
                mutable=True,
                **kwargs,
            )

        y_pred, variables = utils.inject_dependencies(
            apply_fn,
            signature_f=self.module.__call__,
        )(
            *args,
            **kwargs,
        )

        # FrozenDict.pop returns (rest, value): non-param collections
        # first, the popped params second.
        net_states, net_params = (
            variables.pop("params")
            if "params" in variables
            else (variables, FrozenDict())
        )

        return types.OutputStates(y_pred, net_params, net_states)

    return _lambda

def _load_model(rng, arch_type, backbone, pretrained, num_classes, **kwargs):
    model = _make_model(rng, arch_type, backbone, num_classes, **kwargs)
    if pretrained:
        arch = arch_type + '_' + backbone
        if arch not in model_urls:
            raise NotImplementedError(
                'pretrained {} is not supported'.format(arch))
        get_flax_keys_fn = segm_heads[arch_type][1]
        torch_params = utils.load_torch_params(model_urls[arch])
        flax_params = FrozenDict(
            utils.torch_to_linen(torch_params, get_flax_keys_fn))
    else:
        init_batch = jnp.ones((1, 224, 224, 3), jnp.float32)
        flax_params = model.init(rng, init_batch)
    return model, flax_params

def _torch_to_vgg(torch_params, cfg, batch_norm=False):
    """Convert PyTorch parameters to nested dictionaries."""
    flax_params = {
        'params': {'backbone': {}, 'classifier': {}},
        'batch_stats': {'backbone': {}},
    }
    conv_idx = 0
    bn_idx = 0
    tensor_iter = iter(torch_params.items())

    def next_tensor():
        _, tensor = next(tensor_iter)
        return tensor.detach().numpy()

    for layer_cfg in cfg:
        if isinstance(layer_cfg, int):
            # PyTorch conv kernels are (out, in, H, W); flax expects (H, W, in, out).
            flax_params['params']['backbone'][f'Conv_{conv_idx}'] = {
                'kernel': np.transpose(next_tensor(), (2, 3, 1, 0)),
                'bias': next_tensor(),
            }
            conv_idx += 1
            if batch_norm:
                flax_params['params']['backbone'][f'BatchNorm_{bn_idx}'] = {
                    'scale': next_tensor(),
                    'bias': next_tensor(),
                }
                flax_params['batch_stats']['backbone'][f'BatchNorm_{bn_idx}'] = {
                    'mean': next_tensor(),
                    'var': next_tensor(),
                }
                bn_idx += 1

    for idx in range(3):
        # PyTorch dense kernels are (out, in); flax expects (in, out).
        flax_params['params']['classifier'][f'Dense_{idx}'] = {
            'kernel': np.transpose(next_tensor()),
            'bias': next_tensor(),
        }

    return FrozenDict(flax_params)

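# Why the transposes in _torch_to_vgg: PyTorch stores conv kernels as
# (out, in, H, W) while flax.linen.Conv expects (H, W, in, out), and
# dense kernels as (out, in) versus flax.linen.Dense's (in, out).
# A quick numpy-only check with illustrative shapes:
import numpy as np

torch_conv = np.zeros((64, 3, 3, 3))                # (out, in, H, W)
assert np.transpose(torch_conv, (2, 3, 1, 0)).shape == (3, 3, 3, 64)

torch_dense = np.zeros((4096, 25088))               # (out, in)
assert np.transpose(torch_dense).shape == (25088, 4096)
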
def _lambda(*args, **kwargs):
    def init_fn(*args, **kwargs):
        return self.module.init_with_output(rng.next(), *args, **kwargs)

    y_pred, variables = utils.inject_dependencies(
        init_fn,
        signature_f=self.module.__call__,
    )(
        *args,
        **kwargs,
    )

    assert isinstance(variables, FrozenDict)

    net_states, net_params = (
        variables.pop("params")
        if "params" in variables
        else (variables, FrozenDict())
    )

    return types.OutputStates(y_pred, net_params, net_states)

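# The destructuring above relies on FrozenDict.pop returning a
# (remaining_dict, popped_value) pair rather than mutating in place
# the way dict.pop does. A minimal sketch:
from flax.core import FrozenDict

variables = FrozenDict({'params': {'w': 1}, 'batch_stats': {'m': 0.0}})
rest, params = variables.pop('params')
assert 'params' not in rest
assert params == {'w': 1}
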
def test_frozen_dict_maps(self):
    xs = {'a': 1, 'b': {'c': 2}}
    frozen = FrozenDict(xs)
    frozen2 = jax.tree_map(lambda x: x + x, frozen)
    self.assertEqual(unfreeze(frozen2), {'a': 2, 'b': {'c': 4}})

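# jax.tree_map recurses into FrozenDict because flax registers it as a
# pytree node; its leaves can also be extracted directly. A quick check
# (leaves come back in key-sorted order):
import jax

leaves = jax.tree_util.tree_leaves(FrozenDict({'a': 1, 'b': {'c': 2}}))
assert leaves == [1, 2]
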
def test_frozen_dict_pop(self):
    xs = {'a': 1, 'b': {'c': 2}}
    b, a = FrozenDict(xs).pop('a')
    self.assertEqual(a, 1)
    self.assertEqual(unfreeze(b), {'b': {'c': 2}})

def test_frozen_dict_copy_reserved_name(self):
    result = FrozenDict({'a': 1}).copy({'cls': 2})
    self.assertEqual(result, {'a': 1, 'cls': 2})

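# The 'reserved name' test guards against dict keys that would collide
# with constructor parameters if they were ever splatted as keyword
# arguments (e.g. cls or self); copy() must treat them as plain data
# when supplied positionally in a dict. A minimal sketch:
from flax.core import FrozenDict

updated = FrozenDict({'a': 1}).copy({'cls': 2, 'self': 3})
assert dict(updated) == {'a': 1, 'cls': 2, 'self': 3}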