Ejemplo n.º 1
0
 def test_frozen_dict_reduce(self):
   """Rebuilding a nested FrozenDict from ``__reduce__`` yields an equal copy."""
   original = FrozenDict(a=FrozenDict(b=1, c=2))
   # __reduce__ hands back a (callable, args) pair; calling one with the
   # other reconstructs the object.
   factory, factory_args = original.__reduce__()
   rebuilt = factory(*factory_args)
   # A distinct object that still compares equal, both to the source and to
   # the plain nested-dict literal.
   self.assertIsNot(original, rebuilt)
   self.assertEqual(original, rebuilt)
   self.assertEqual(rebuilt, {'a': {'b': 1, 'c': 2}})
Ejemplo n.º 2
0
    def test_frozen_dict_repr(self):
        expected = ("""FrozenDict({
    a: 1,
    b: {
        c: 2,
        d: {},
    },
})""")

        xs = FrozenDict({'a': 1, 'b': {'c': 2, 'd': {}}})
        self.assertEqual(repr(xs), expected)
        self.assertEqual(repr(FrozenDict()), 'FrozenDict({})')
Ejemplo n.º 3
0
def _resnet(rng, arch, block, layers, pretrained, **kwargs):
  """Build a ResNet module together with its variables.

  Args:
    rng: PRNG key passed to ``init`` when ``pretrained`` is False; unused
      otherwise.
    arch: key into ``model_urls`` selecting the pretrained checkpoint.
    block: forwarded to ``ResNet`` as ``block``.
    layers: forwarded to ``ResNet`` as ``layers``.
    pretrained: if True, download PyTorch weights and convert them with
      ``utils.torch_to_linen``; otherwise initialize the module from ``rng``.
    **kwargs: extra keyword arguments forwarded to ``ResNet``.

  Returns:
    A ``(module, variables)`` tuple.
  """
  resnet = ResNet(block=block, layers=layers, **kwargs)

  if pretrained:
    torch_params = utils.load_torch_params(model_urls[arch])
    flax_params = FrozenDict(utils.torch_to_linen(torch_params, _get_flax_keys))
  else:
    # Initialize the instance we return instead of constructing a second,
    # identically-configured ResNet only to call ``init`` on it.
    # One dummy 224x224x3 image (presumably NHWC, as Flax convs expect)
    # is enough to trace parameter shapes.
    init_batch = jnp.ones((1, 224, 224, 3), jnp.float32)
    flax_params = resnet.init(rng, init_batch)

  return resnet, flax_params
Ejemplo n.º 4
0
def inception_v3(rng, pretrained=True, **kwargs):
    """Build an Inception v3 module together with its variables.

    Args:
        rng: PRNG key passed to ``init`` when ``pretrained`` is False;
            unused otherwise.
        pretrained: if True (the default), download the ``'inception_v3'``
            PyTorch checkpoint and convert it with ``utils.torch_to_linen``;
            otherwise initialize the module from ``rng``.
        **kwargs: extra keyword arguments forwarded to ``Inception``.

    Returns:
        A ``(module, variables)`` tuple.
    """
    inception = Inception(**kwargs)

    if pretrained:
        torch_params = utils.load_torch_params(model_urls['inception_v3'])
        flax_params = FrozenDict(
            utils.torch_to_linen(torch_params, _get_flax_keys))
    else:
        # Initialize the instance we return rather than building a second
        # Inception with the same kwargs just to call ``init``.
        # 299x299x3 dummy input (presumably NHWC, as Flax convs expect).
        init_batch = jnp.ones((1, 299, 299, 3), jnp.float32)
        flax_params = inception.init(rng, init_batch)

    return inception, flax_params
Ejemplo n.º 5
0
def _vgg(rng, arch, cfg, batch_norm, pretrained, **kwargs):
    """Build a VGG module together with its variables.

    Args:
        rng: PRNG key passed to ``init`` when ``pretrained`` is False;
            unused otherwise.
        arch: key into ``model_urls`` selecting the pretrained checkpoint.
        cfg: key into ``cfgs`` selecting the layer configuration.
        batch_norm: whether the VGG variant uses BatchNorm after each conv.
        pretrained: if True, download PyTorch weights and convert them with
            ``_torch_to_vgg``; otherwise initialize the module from ``rng``.
        **kwargs: extra keyword arguments forwarded to ``VGG``.

    Returns:
        A ``(module, variables)`` tuple.
    """
    # Resolve the layer configuration once; both branches need the same one.
    layer_cfg = cfgs[cfg]
    vgg = VGG(cfg=layer_cfg, batch_norm=batch_norm, **kwargs)

    if pretrained:
        torch_params = utils.load_torch_params(model_urls[arch])
        flax_params = FrozenDict(
            _torch_to_vgg(torch_params, layer_cfg, batch_norm))
    else:
        # Initialize the instance we return rather than constructing an
        # identical second VGG just to call ``init`` on it.
        init_batch = jnp.ones((1, 224, 224, 3), jnp.float32)
        flax_params = vgg.init(rng, init_batch)

    return vgg, flax_params
Ejemplo n.º 6
0
    def apply(
        self,
        params: tp.Any,
        states: tp.Any,
        training: bool,
        rng: types.RNGSeq,
    ) -> tp.Callable[..., types.OutputStates]:
        """Return a callable that runs ``self.module`` with the given variables.

        Args:
            params: the ``"params"`` collection; ``None`` is treated as empty.
            states: the remaining variable collections, merged alongside
                ``params`` when calling the module; ``None`` is treated as
                empty.
            training: not read by this method.
            rng: RNG sequence; ``rng.next()`` is called on every invocation
                of the returned callable and passed as the ``"params"`` rng.

        Returns:
            A function that forwards its arguments (via
            ``utils.inject_dependencies``) to ``self.module.apply`` with
            ``mutable=True`` and returns
            ``types.OutputStates(y_pred, net_params, net_states)``, where
            ``net_params`` is the updated ``"params"`` collection (empty if
            absent) and ``net_states`` holds every other collection.
        """
        if params is None:
            params = FrozenDict()

        if states is None:
            states = FrozenDict()

        def _lambda(*args, **kwargs):
            def apply_fn(*args, **kwargs):
                variables = dict(params=params, **states)
                return self.module.apply(
                    variables,
                    *args,
                    rngs={"params": rng.next()},
                    mutable=True,
                    **kwargs,
                )

            y_pred, variables = utils.inject_dependencies(
                apply_fn,
                signature_f=self.module.__call__,
            )(
                *args,
                **kwargs,
            )

            # FrozenDict.pop returns a (remaining, popped) pair, which the
            # unpack below relies on. dict.pop returns only the popped value,
            # so a plain dict (which newer Flax releases may return) would be
            # silently mis-unpacked. Normalize first.
            if not isinstance(variables, FrozenDict):
                variables = FrozenDict(variables)

            if "params" in variables:
                net_states, net_params = variables.pop("params")
            else:
                net_states, net_params = variables, FrozenDict()

            return types.OutputStates(y_pred, net_params, net_states)

        return _lambda
Ejemplo n.º 7
0
def _load_model(rng, arch_type, backbone, pretrained, num_classes, **kwargs):
  """Construct a model via ``_make_model`` and obtain its variables.

  Args:
    rng: PRNG key; forwarded to ``_make_model`` and, when ``pretrained`` is
      False, used for ``init``.
    arch_type: architecture name; joined with ``backbone`` to form the
      ``model_urls`` key and used to index ``segm_heads``.
    backbone: backbone name.
    pretrained: load converted PyTorch weights instead of initializing.
    num_classes: forwarded to ``_make_model``.
    **kwargs: forwarded to ``_make_model``.

  Returns:
    A ``(model, variables)`` tuple.

  Raises:
    NotImplementedError: if ``pretrained`` is True and no checkpoint URL is
      registered for ``arch_type + '_' + backbone``.
  """
  model = _make_model(rng, arch_type, backbone, num_classes, **kwargs)

  if not pretrained:
    dummy_input = jnp.ones((1, 224, 224, 3), jnp.float32)
    return model, model.init(rng, dummy_input)

  arch = arch_type + '_' + backbone
  if arch not in model_urls:
    raise NotImplementedError('pretrained {} is not supported'.format(arch))

  # segm_heads[arch_type][1] maps PyTorch parameter names to Flax keys.
  key_mapper = segm_heads[arch_type][1]
  torch_params = utils.load_torch_params(model_urls[arch])
  return model, FrozenDict(utils.torch_to_linen(torch_params, key_mapper))
Ejemplo n.º 8
0
def _torch_to_vgg(torch_params, cfg, batch_norm=False):
    """Convert an ordered PyTorch VGG state dict into Flax variables.

    Tensors are consumed strictly in the iteration order of ``torch_params``.
    The mapping therefore assumes the state dict lists the conv (and optional
    BatchNorm) layers in the same order as ``cfg``, followed by exactly three
    Linear classifier layers.

    Args:
        torch_params: mapping of parameter name -> tensor; each tensor must
            support ``.detach().numpy()``.
        cfg: layer configuration. Every ``int`` entry produces one conv layer
            (plus one BatchNorm when ``batch_norm`` is True). Non-int entries
            consume no tensors.
        batch_norm: whether each conv layer is followed by a BatchNorm.

    Returns:
        FrozenDict with a ``'params'`` collection (``'backbone'`` and
        ``'classifier'``) and a ``'batch_stats'`` collection (``'backbone'``).

    Raises:
        StopIteration: if ``torch_params`` holds fewer tensors than ``cfg``
            and the three classifier layers require.
    """
    flax_params = {
        'params': {
            'backbone': {},
            'classifier': {}
        },
        'batch_stats': {
            'backbone': {}
        }
    }
    conv_idx = 0
    bn_idx = 0

    # BatchNorm ``num_batches_tracked`` buffers have no Flax counterpart.
    # Because tensors are matched by position, leaving them in would shift
    # every later tensor by one (e.g. the counter would be read as the next
    # conv kernel). Skip them.
    tensor_iter = (
        tensor for name, tensor in torch_params.items()
        if not name.endswith('num_batches_tracked'))

    def next_tensor():
        return next(tensor_iter).detach().numpy()

    for layer_cfg in cfg:
        if isinstance(layer_cfg, int):
            # PyTorch conv weights are (out, in, kh, kw); Flax expects
            # (kh, kw, in, out).
            flax_params['params']['backbone'][f'Conv_{conv_idx}'] = {
                'kernel': np.transpose(next_tensor(), (2, 3, 1, 0)),
                'bias': next_tensor(),
            }
            conv_idx += 1

            if batch_norm:
                # Order follows the state dict: weight, bias, running_mean,
                # running_var.
                flax_params['params']['backbone'][f'BatchNorm_{bn_idx}'] = {
                    'scale': next_tensor(),
                    'bias': next_tensor(),
                }
                flax_params['batch_stats']['backbone'][
                    f'BatchNorm_{bn_idx}'] = {
                        'mean': next_tensor(),
                        'var': next_tensor(),
                    }
                bn_idx += 1

    # PyTorch Linear weights are (out, in); Flax Dense kernels are (in, out).
    for idx in range(3):
        flax_params['params']['classifier'][f'Dense_{idx}'] = {
            'kernel': np.transpose(next_tensor()),
            'bias': next_tensor(),
        }

    return FrozenDict(flax_params)
Ejemplo n.º 9
0
        def _lambda(*args, **kwargs):
            """Initialize ``self.module`` on the given inputs.

            Returns ``types.OutputStates(y_pred, net_params, net_states)``,
            where ``net_params`` is the ``"params"`` collection (empty if
            absent) and ``net_states`` holds every other collection.
            """
            def init_fn(*args, **kwargs):
                return self.module.init_with_output(rng.next(), *args,
                                                    **kwargs)

            y_pred, variables = utils.inject_dependencies(
                init_fn,
                signature_f=self.module.__call__,
            )(
                *args,
                **kwargs,
            )

            # The unpack below relies on FrozenDict.pop returning a
            # (remaining, popped) pair; dict.pop returns only the value. An
            # ``assert isinstance`` guard would be stripped under ``python -O``,
            # so normalize to FrozenDict instead.
            if not isinstance(variables, FrozenDict):
                variables = FrozenDict(variables)

            if "params" in variables:
                net_states, net_params = variables.pop("params")
            else:
                net_states, net_params = variables, FrozenDict()

            return types.OutputStates(y_pred, net_params, net_states)
Ejemplo n.º 10
0
        def _lambda(*args, **kwargs):
            """Apply ``self.module`` with all collections mutable.

            Returns ``types.OutputStates(y_pred, net_params, net_states)``,
            where ``net_params`` is the updated ``"params"`` collection (empty
            if absent) and ``net_states`` holds every other collection.
            """
            def apply_fn(*args, **kwargs):
                variables = dict(params=params, **states)
                return self.module.apply(
                    variables,
                    *args,
                    rngs={"params": rng.next()},
                    mutable=True,
                    **kwargs,
                )

            y_pred, variables = utils.inject_dependencies(
                apply_fn,
                signature_f=self.module.__call__,
            )(
                *args,
                **kwargs,
            )

            # FrozenDict.pop returns a (remaining, popped) pair; dict.pop
            # returns only the value, so a plain dict (which newer Flax
            # releases may return) would be silently mis-unpacked.
            if not isinstance(variables, FrozenDict):
                variables = FrozenDict(variables)

            if "params" in variables:
                net_states, net_params = variables.pop("params")
            else:
                net_states, net_params = variables, FrozenDict()

            return types.OutputStates(y_pred, net_params, net_states)
Ejemplo n.º 11
0
 def test_frozen_dict_maps(self):
     """tree_map over a FrozenDict transforms every leaf, including nested ones."""
     xs = {'a': 1, 'b': {'c': 2}}
     frozen = FrozenDict(xs)
     # ``jax.tree_map`` is deprecated and removed in recent JAX releases;
     # ``jax.tree_util.tree_map`` is available across versions.
     frozen2 = jax.tree_util.tree_map(lambda x: x + x, frozen)
     self.assertEqual(unfreeze(frozen2), {'a': 2, 'b': {'c': 4}})
Ejemplo n.º 12
0
 def test_frozen_dict_pop(self):
     """FrozenDict.pop returns a (remaining_mapping, popped_value) pair."""
     source = {'a': 1, 'b': {'c': 2}}
     remaining, popped = FrozenDict(source).pop('a')
     self.assertEqual(popped, 1)
     self.assertEqual(unfreeze(remaining), {'b': {'c': 2}})
Ejemplo n.º 13
0
 def test_frozen_dict_copy_reserved_name(self):
   """copy() merges in an update whose key is ``'cls'``."""
   # NOTE(review): 'cls' is presumably chosen because it collides with a
   # parameter name inside FrozenDict's implementation — confirm.
   base = FrozenDict({'a': 1})
   merged = base.copy({'cls': 2})
   self.assertEqual(merged, {'a': 1, 'cls': 2})