Example #1
    def test_to_fp32(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)

            # cast all params to fp16 and back to fp32
            params = model.to_fp16(model.params)
            params = model.to_fp32(params)

            # test if all params are in fp32
            types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
            for name, type_ in types.items():
                self.assertEqual(type_, jnp.float32, msg=f"param {name} is not in fp32.")

            # test masking
            flat_params = flatten_dict(params)
            key = random.choice(list(flat_params.keys()))  # choose a random param
            mask = {path: path != key for path in flat_params}  # don't cast the key
            mask = unflatten_dict(mask)

            # cast to fp16 and back to fp32 with mask
            params = model.to_fp16(model.params)
            params = model.to_fp32(params, mask)

            # test if all params are in fp32 except key
            types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
            for name, type_ in types.items():
                if name == key:
                    self.assertEqual(type_, jnp.float16, msg=f"param {name} should be in fp16.")
                else:
                    self.assertEqual(type_, jnp.float32, msg=f"param {name} is not in fp32.")
Example #2
    def init_weights(self,
                     rng: jax.random.PRNGKey,
                     input_shape: Tuple,
                     params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)

        batch_size, sequence_length = input_ids.shape
        position_ids = jnp.broadcast_to(
            jnp.arange(sequence_length)[None, :],
            (batch_size, sequence_length))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        module_init_outputs = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            position_ids,
            return_dict=False,
        )

        random_params = module_init_outputs["params"]
        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params
    def test_push_to_hub_in_organization(self):
        config = BertConfig(vocab_size=99,
                            hidden_size=32,
                            num_hidden_layers=5,
                            num_attention_heads=4,
                            intermediate_size=37)
        model = FlaxBertModel(config)
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(
                os.path.join(tmp_dir, "test-model-flax-org"),
                push_to_hub=True,
                use_auth_token=self._token,
                organization="valid_org",
            )

            new_model = FlaxBertModel.from_pretrained(
                "valid_org/test-model-flax-org")

            base_params = flatten_dict(unfreeze(model.params))
            new_params = flatten_dict(unfreeze(new_model.params))

            for key in base_params.keys():
                max_diff = (base_params[key] - new_params[key]).sum().item()
                self.assertLessEqual(max_diff,
                                     1e-3,
                                     msg=f"{key} not identical")
    def init_weights(self,
                     rng: jax.random.PRNGKey,
                     input_shape: Tuple,
                     params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
        attention_mask = jnp.ones_like(input_ids)
        head_mask = jnp.ones(
            (self.config.num_hidden_layers, self.config.num_attention_heads))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(rngs,
                                         input_ids,
                                         attention_mask,
                                         token_type_ids,
                                         head_mask,
                                         return_dict=False)["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params
Example #5
    def test_save_load_to_base_pt(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = FLAX_MODEL_MAPPING[config.__class__]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            model = model_class(config)
            base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
    def test_save_load_to_base(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = FLAX_MODEL_MAPPING[config.__class__]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            model = model_class(config)
            base_params_from_head = flatten_dict(
                unfreeze(model.params[model.base_model_prefix]))

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] -
                                base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff,
                                         1e-3,
                                         msg=f"{key} not identical")
    def init_weights(self,
                     rng: jax.random.PRNGKey,
                     input_shape: Tuple,
                     params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        pixel_values = jnp.zeros(input_shape, dtype=self.dtype)

        params_rng, dropout_rng = jax.random.split(rng)
        dropout_rng, droppath_rng = jax.random.split(dropout_rng)
        rngs = {
            "params": params_rng,
            "dropout": dropout_rng,
            "droppath": droppath_rng
        }

        random_params = self.module.init(rngs, pixel_values,
                                         return_dict=False)["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params
    def init_weights(self,
                     rng: jax.random.PRNGKey,
                     input_shape: Tuple,
                     params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(rngs,
                                         input_ids,
                                         attention_mask,
                                         position_ids,
                                         return_dict=False)["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params
Example #9
def load_pretrained(variables, url='', default_cfg=None, filter_fn=None):
    if not url:
        assert default_cfg is not None and default_cfg['url']
        url = default_cfg['url']
    state_dict = load_state_dict_from_url(url, transpose=True)

    source_params, source_state = split_state_dict(state_dict)
    if filter_fn is not None:
        # filter after split as we may have modified the split criteria (ie bn running vars)
        source_params = filter_fn(source_params)
        source_state = filter_fn(source_state)

    # FIXME better way to do this?
    var_unfrozen = unfreeze(variables)
    missing_keys = []
    flat_params = flatten_dict(var_unfrozen['params'])
    flat_param_keys = set()
    for k, v in flat_params.items():
        flat_k = '.'.join(k)
        if flat_k in source_params:
            # compare the loaded weight against the model param (v is flat_params[k])
            assert source_params[flat_k].shape == v.shape
            flat_params[k] = source_params[flat_k]
        else:
            missing_keys.append(flat_k)
        flat_param_keys.add(flat_k)
    unexpected_keys = list(
        set(source_params.keys()).difference(flat_param_keys))
    params = freeze(unflatten_dict(flat_params))

    flat_state = flatten_dict(var_unfrozen['batch_stats'])
    flat_state_keys = set()
    for k, v in flat_state.items():
        flat_k = '.'.join(k)
        if flat_k in source_state:
            assert source_state[flat_k].shape == v.shape
            flat_state[k] = source_state[flat_k]
        else:
            missing_keys.append(flat_k)
        flat_state_keys.add(flat_k)
    unexpected_keys.extend(
        list(set(source_state.keys()).difference(flat_state_keys)))
    batch_stats = freeze(unflatten_dict(flat_state))

    if missing_keys:
        print(
            f' WARNING: {len(missing_keys)} keys missing while loading state_dict. {str(missing_keys)}'
        )
    if unexpected_keys:
        print(
            f' WARNING: {len(unexpected_keys)} unexpected keys found while loading state_dict. {str(unexpected_keys)}'
        )

    return dict(params=params, batch_stats=batch_stats)
    def init_weights(self,
                     rng: jax.random.PRNGKey,
                     input_shape: Tuple,
                     params: FrozenDict = None) -> FrozenDict:
        encoder_input_shape, decoder_input_shape = input_shape

        # init input tensors
        input_ids = jnp.zeros(encoder_input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)

        batch_size, sequence_length = input_ids.shape
        position_ids = jnp.broadcast_to(
            jnp.arange(sequence_length)[None, :],
            (batch_size, sequence_length))

        decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
        if decoder_batch_size != batch_size:
            raise ValueError(
                f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder and {decoder_batch_size} for decoder."
            )
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(decoder_sequence_length)[None, :],
            (decoder_batch_size, decoder_sequence_length))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
            position_ids,
            decoder_position_ids,
        )["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params
Example #11
 def decay_mask_fn(params):
     flat_params = traverse_util.flatten_dict(params)
     layer_norm_params = [
         (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
     ]
     flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
     return traverse_util.unflatten_dict(flat_mask)
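A mask function like this is typically passed to an optax optimizer so that weight decay skips biases and layer-norm scales. A minimal sketch of that wiring, assuming the decay_mask_fn above; the learning rate and weight decay values are illustrative, not taken from any particular training script:

import optax

# optax calls the mask function on the params pytree and applies weight decay
# only where the returned mask is True.
optimizer = optax.adamw(
    learning_rate=1e-3,   # illustrative value
    weight_decay=0.01,    # illustrative value
    mask=decay_mask_fn,
)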
Example #12
def set_partitions(num_partitions, in_dict):
    rules = _get_partition_rules(num_partitions)
    replace = _replacement_rules(rules)
    initd = {k: _unmatched for k in flatten_dict(in_dict)}
    result = {k: replace(k, v) for k, v in initd.items()}
    assert _unmatched not in result.values(), 'Incomplete partition spec.'
    return unflatten_dict(result)
Example #13
    def _cast_floating_to(self,
                          params: Union[Dict, FrozenDict],
                          dtype: jnp.dtype,
                          mask: Any = None) -> Any:
        """
        Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
        """

        # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
        def conditional_cast(param):
            if isinstance(param, jnp.ndarray) and jnp.issubdtype(
                    param.dtype, jnp.floating):
                param = param.astype(dtype)
            return param

        if mask is None:
            return jax.tree_map(conditional_cast, params)

        flat_params = flatten_dict(params)
        flat_mask, _ = jax.tree_flatten(mask)

        for masked, key in zip(flat_mask, flat_params.keys()):
            if masked:
                param = flat_params[key]
                flat_params[key] = conditional_cast(param)

        return unflatten_dict(flat_params)
Example #14
def update_GCNN_parity(params):
    """Adds biases of parity-flip layers to the corresponding no-flip layers.
    Corrects for changes in GCNN_parity due to PR #1030 in NetKet 3.3.

    Args:
        params: a parameter pytree
    """
    # unfreeze just in case, doesn't break with a plain dict
    params = flatten_dict(unfreeze(params))
    to_remove = []
    for path in params:
        if (
            len(path) > 1
            and path[-2].startswith("equivariant_layers_flip")
            and path[-1] == "bias"
        ):
            alt_path = (
                *path[:-2],
                path[-2].replace("equivariant_layers_flip", "equivariant_layers"),
                path[-1],
            )
            params[alt_path] = params[alt_path] + params[path]
            to_remove.append(path)
    for path in to_remove:
        del params[path]
    return unflatten_dict(params)
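To make the transformation concrete, here is a small hedged walk-through with made-up layer names and scalar leaves (real GCNN parameters are arrays, and the exact layer names depend on the NetKet model):

# Hypothetical input: one no-flip layer and its parity-flip counterpart.
old = {"equivariant_layers_0": {"bias": 1.0},
       "equivariant_layers_flip_0": {"bias": 2.0, "kernel": 3.0}}
new = update_GCNN_parity(old)
# The flip layer's bias is added into the matching no-flip layer and removed:
# new == {"equivariant_layers_0": {"bias": 3.0},
#         "equivariant_layers_flip_0": {"kernel": 3.0}}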
Example #15
 def decay_mask_fn(params):
     flat_params = traverse_util.flatten_dict(params)
     flat_mask = {
         path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale"))
         for path in flat_params
     }
     return traverse_util.unflatten_dict(flat_mask)
Example #16
    def __init__(
        self,
        config: PretrainedConfig,
        module: nn.Module,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
    ):
        if config is None:
            raise ValueError("config cannot be None")

        if module is None:
            raise ValueError("module cannot be None")

        # Those are private to be exposed as typed property on derived classes.
        self._config = config
        self._module = module

        # Those are public as their type is generic to every derived class.
        self.key = PRNGKey(seed)
        self.dtype = dtype

        # randomly initialized parameters
        random_params = self.init(self.key, input_shape)

        # save required_params as set
        self._required_params = set(
            flatten_dict(unfreeze(random_params)).keys())
        self.params = random_params
Example #17
 def test_flatten_dict(self):
     xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
     flat_xs = traverse_util.flatten_dict(xs)
     self.assertEqual(flat_xs, {
         ('foo', ): 1,
         ('bar', 'a'): 2,
     })
Example #18
 def decay_mask_fn(params):
     flat_params = traverse_util.flatten_dict(params)
     flat_mask = {
         path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")])
         for path in flat_params
     }
     return traverse_util.unflatten_dict(flat_mask)
Example #19
 def weight_decay_mask(params):
     params = traverse_util.flatten_dict(params)
     mask = {
         k: (v[-1] != "bias" and v[-2:] != ("LayerNorm", "scale"))
         for k, v in params.items()
     }
     return traverse_util.unflatten_dict(mask)
Example #20
def get_suffix_module_pairs(module_tree) -> List[Tuple[str, Type["Module"]]]:
    """Helper for naming pytrees of submodules."""
    if isinstance(module_tree, Module):
        return [('', module_tree)]
    else:
        flat_tree = traverse_util.flatten_dict(
            serialization.to_state_dict(module_tree))
        return [('_' + '_'.join(k), v) for k, v in flat_tree.items()]
Example #21
def _get_suffix_value_pairs(
        tree_or_leaf: Any) -> List[Tuple[str, Type["Module"]]]:
    """Helper for naming pytrees of submodules."""
    dict_or_leaf = serialization.to_state_dict(tree_or_leaf)
    if dict_or_leaf == {} or not isinstance(dict_or_leaf, dict):
        return [('', tree_or_leaf)]
    else:
        flat_dict = traverse_util.flatten_dict(dict_or_leaf)
        return [('_' + '_'.join(k), v) for k, v in flat_dict.items()]
Example #22
 def params(self, params: Union[Dict, FrozenDict]):
     if isinstance(params, FrozenDict):
         params = unfreeze(params)
     param_keys = set(flatten_dict(params).keys())
     if len(self.required_params - param_keys) > 0:
         raise ValueError(
             "Some parameters are missing. Make sure that `params` include the following "
             f"parameters {self.required_params - param_keys}")
     self._params = freeze(params)
Example #23
def partition_nested_dict(d, flat_left_keys):
    left, right = {}, {}
    flat_left_keys = set(flat_left_keys)
    for k, v in flatten_dict(d).items():
        if k in flat_left_keys:
            left[k] = v
        else:
            right[k] = v
    return tuple(map(unflatten_dict, (left, right)))
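For illustration, a small usage sketch with a made-up parameter dict; the keys passed as flat_left_keys are tuples matching the paths produced by flatten_dict:

# Hypothetical usage: split the "embed" kernel out from the rest of the params.
params = {"embed": {"kernel": 1}, "dense": {"kernel": 2, "bias": 3}}
left, right = partition_nested_dict(params, [("embed", "kernel")])
# left  == {"embed": {"kernel": 1}}
# right == {"dense": {"kernel": 2, "bias": 3}}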
Example #24
def dict_replace(col, target, leaf_only=True):
    col_flat = flatten_dict(unfreeze(col))
    diff = {}
    for keys_flat in col_flat.keys():
        for tar_key, tar_val in target.items():
            if (keys_flat[-1] == tar_key if leaf_only else
                (tar_key in keys_flat)):
                diff[keys_flat] = tar_val
    col_flat.update(diff)
    col = unflatten_dict(col_flat)
    return col
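A small hedged usage sketch with a made-up parameter dict: every leaf whose final path component matches a target key is overwritten, and all other leaves are left untouched:

params = {"dense": {"kernel": 1.0, "bias": 2.0}}
params = dict_replace(params, {"bias": 0.0})
# params == {"dense": {"kernel": 1.0, "bias": 0.0}}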
Example #25
def _map_over_modules_in_tree(fn, tree_or_leaf):
  """Helper for mapping function over submodules."""
  dict_or_leaf = serialization.to_state_dict(tree_or_leaf)
  if not isinstance(dict_or_leaf, dict) or dict_or_leaf == {}:
    return fn('', tree_or_leaf)
  else:
    flat_dict = traverse_util.flatten_dict(dict_or_leaf, keep_empty_nodes=True)
    mapped_flat_dict = {k: fn('_' + '_'.join(k), v)
                        for k, v in _sorted_items(flat_dict)}
    return serialization.from_state_dict(
        tree_or_leaf, traverse_util.unflatten_dict(mapped_flat_dict))
Example #26
 def test_flatten_dict_keep_empty(self):
     xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
     flat_xs = traverse_util.flatten_dict(xs, keep_empty_nodes=True)
     self.assertEqual(
         flat_xs, {
             ('foo', ): 1,
             ('bar', 'a'): 2,
             ('bar', 'b'): traverse_util.empty_node,
         })
     xs_restore = traverse_util.unflatten_dict(flat_xs)
     self.assertEqual(xs, xs_restore)
Example #27
    def test_default_params_dtype(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            # check if all params are still in float32 when dtype of computation is half-precision
            model = model_class(config, dtype=jnp.float16)
            types = jax.tree_map(lambda x: x.dtype, model.params)
            types = flatten_dict(types)

            for name, type_ in types.items():
                self.assertEqual(type_, jnp.float32, msg=f"param {name} is not initialized in fp32.")
Example #28
  def restore_state(self, state_dict):
    """Restore parameter and optimizer state from state dictionary.

    Adapted from
    https://github.com/google-research/t5x/blob/main/t5x/optimizers.py. Includes
    support to handle `optax.EmptyState`.

    Args:
      state_dict: Contains desired new parameters and optimizer state

    Returns:
      Updated train state.
    """
    params = serialization.from_state_dict(self.params, state_dict["params"])

    # Get all the possible keys in the reference optimizer state.
    flat_ref_opt_state_dict = traverse_util.flatten_dict(
        serialization.to_state_dict(self.opt_state),
        keep_empty_nodes=True,
        sep="/")

    flat_src_opt_state_dict = dict(
        traverse_util.flatten_dict(state_dict["opt_state"], sep="/"))
    # Adding the empty paths back to flat_src_opt_state_dict.
    for k, v in flat_ref_opt_state_dict.items():
      if k in flat_src_opt_state_dict:
        continue
      # The key is not in the input state dict, presumably because it
      # corresponds to an empty dict.
      if v != traverse_util.empty_node:
        raise ValueError(
            f"Failed to restore optimizer state, path {k} is not present "
            "in the input optimizer state dict.")
      flat_src_opt_state_dict[k] = v

    # Restore state from the enhanced state dict.
    opt_state = serialization.from_state_dict(
        self.opt_state,
        traverse_util.unflatten_dict(flat_src_opt_state_dict, sep="/"))
    return self.replace(params=params, opt_state=opt_state)
Example #29
 def test_flatten_dict_is_leaf(self):
     xs = {'foo': {'c': 4}, 'bar': {'a': 2, 'b': {}}}
     flat_xs = traverse_util.flatten_dict(
         xs, is_leaf=lambda k, x: len(k) == 1 and len(x) == 2)
     self.assertEqual(flat_xs, {
         ('foo', 'c'): 4,
         ('bar', ): {
             'a': 2,
             'b': {}
         },
     })
     xs_restore = traverse_util.unflatten_dict(flat_xs)
     self.assertEqual(xs, xs_restore)
Example #30
    def __init__(
        self,
        config: PretrainedConfig,
        module: nn.Module,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
    ):
        if config is None:
            raise ValueError("config cannot be None")

        if module is None:
            raise ValueError("module cannot be None")

        # Those are private to be exposed as typed property on derived classes.
        self._config = config
        self._module = module

        # Those are public as their type is generic to every derived class.
        self.key = PRNGKey(seed)
        self.dtype = dtype
        self.input_shape = input_shape

        # To check if the model was initialized automatically.
        self._is_initialized = _do_init

        if _do_init:
            # randomly initialized parameters
            random_params = self.init_weights(self.key, input_shape)
            params_shape_tree = jax.eval_shape(lambda params: params,
                                               random_params)
        else:
            init_fn = partial(self.init_weights, input_shape=input_shape)
            params_shape_tree = jax.eval_shape(init_fn, self.key)

            logger.info(
                "Model weights are not initialized as `_do_init` is set to `False`. "
                f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights."
            )

        # get the shape of the parameters
        self._params_shape_tree = params_shape_tree

        # save required_params as set
        self._required_params = set(
            flatten_dict(unfreeze(params_shape_tree)).keys())

        # initialize the parameters
        if _do_init:
            self.params = random_params