Example #1
def load_pretrained(variables, url='', default_cfg=None, filter_fn=None):
    if not url:
        assert default_cfg is not None and default_cfg['url']
        url = default_cfg['url']
    state_dict = load_state_dict_from_url(url, transpose=True)

    source_params, source_state = split_state_dict(state_dict)
    if filter_fn is not None:
        # filter after split as we may have modified the split criteria (ie bn running vars)
        source_params = filter_fn(source_params)
        source_state = filter_fn(source_state)

    # FIXME better way to do this?
    var_unfrozen = unfreeze(variables)
    missing_keys = []
    flat_params = flatten_dict(var_unfrozen['params'])
    flat_param_keys = set()
    for k, v in flat_params.items():
        flat_k = '.'.join(k)
        if flat_k in source_params:
            assert v.shape == source_params[flat_k].shape  # model param shape must match checkpoint shape
            flat_params[k] = source_params[flat_k]
        else:
            missing_keys.append(flat_k)
        flat_param_keys.add(flat_k)
    unexpected_keys = list(
        set(source_params.keys()).difference(flat_param_keys))
    params = freeze(unflatten_dict(flat_params))

    flat_state = flatten_dict(var_unfrozen['batch_stats'])
    flat_state_keys = set()
    for k, v in flat_state.items():
        flat_k = '.'.join(k)
        if flat_k in source_state:
            assert v.shape == source_state[flat_k].shape  # model stat shape must match checkpoint shape
            flat_state[k] = source_state[flat_k]
        else:
            missing_keys.append(flat_k)
        flat_state_keys.add(flat_k)
    unexpected_keys.extend(
        list(set(source_state.keys()).difference(flat_state_keys)))
    batch_stats = freeze(unflatten_dict(flat_state))

    if missing_keys:
        print(
            f' WARNING: {len(missing_keys)} keys missing while loading state_dict. {str(missing_keys)}'
        )
    if unexpected_keys:
        print(
            f' WARNING: {len(unexpected_keys)} unexpected keys found while loading state_dict. {str(unexpected_keys)}'
        )

    return dict(params=params, batch_stats=batch_stats)
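For orientation, a hedged usage sketch (the `model`, `rng`, `dummy_input` and the URL below are placeholders, not part of the snippet above): the returned dict slots straight back into a Flax variable collection.

# Sketch only: obtain initial variables from a Flax model, then overwrite them
# with pretrained weights; the URL is a placeholder checkpoint location.
variables = model.init(rng, dummy_input)
variables = load_pretrained(variables, url='https://example.com/checkpoint.pth')
params, batch_stats = variables['params'], variables['batch_stats']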
Example #2
 def weight_decay_mask(params):
     params = traverse_util.flatten_dict(params)
     mask = {
         k: (v[-1] != "bias" and v[-2:] != ("LayerNorm", "scale"))
         for k, v in params.items()
     }
     return traverse_util.unflatten_dict(mask)
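Masks like this are typically passed to an optimizer that supports per-parameter weight decay; a minimal sketch assuming optax, with placeholder learning-rate and decay values:

import optax

# `params` is the model's parameter pytree; decay is applied only where the mask is True,
# i.e. not to biases or LayerNorm scales.
tx = optax.adamw(learning_rate=1e-3, weight_decay=1e-2, mask=weight_decay_mask)
opt_state = tx.init(params)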
Example #3
def update_GCNN_parity(params):
    """Adds biases of parity-flip layers to the corresponding no-flip layers.
    Corrects for changes in GCNN_parity due to PR #1030 in NetKet 3.3.

    Args:
        params: a parameter pytree
    """
    # unfreeze just in case, doesn't break with a plain dict
    params = flatten_dict(unfreeze(params))
    to_remove = []
    for path in params:
        if (
            len(path) > 1
            and path[-2].startswith("equivariant_layers_flip")
            and path[-1] == "bias"
        ):
            alt_path = (
                *path[:-2],
                path[-2].replace("equivariant_layers_flip", "equivariant_layers"),
                path[-1],
            )
            params[alt_path] = params[alt_path] + params[path]
            to_remove.append(path)
    for path in to_remove:
        del params[path]
    return unflatten_dict(params)
Example #4
    def test_to_fp32(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)

            # cast all params to fp16 and back to fp32
            params = model.to_fp16(model.params)
            params = model.to_fp32(params)

            # test if all params are in fp32
            types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
            for name, type_ in types.items():
                self.assertEqual(type_, jnp.float32, msg=f"param {name} is not in fp32.")

            # test masking
            flat_params = flatten_dict(params)
            key = random.choice(list(flat_params.keys()))  # choose a random param
            mask = {path: path != key for path in flat_params}  # don't cast the key
            mask = unflatten_dict(mask)

            # cast to fp16 and back to fp32 with mask
            params = model.to_fp16(model.params)
            params = model.to_fp32(params, mask)

            # test if all params are in fp32 except key
            types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
            for name, type_ in types.items():
                if name == key:
                    self.assertEqual(type_, jnp.float16, msg=f"param {name} should be in fp16.")
                else:
                    self.assertEqual(type_, jnp.float32, msg=f"param {name} is not in fp32.")
Example #5
    def init_weights(self,
                     rng: jax.random.PRNGKey,
                     input_shape: Tuple,
                     params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)

        batch_size, sequence_length = input_ids.shape
        position_ids = jnp.broadcast_to(
            jnp.arange(sequence_length)[None, :],
            (batch_size, sequence_length))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        module_init_outputs = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            position_ids,
            return_dict=False,
        )

        random_params = module_init_outputs["params"]
        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params
Example #6
    def init_weights(self,
                     rng: jax.random.PRNGKey,
                     input_shape: Tuple,
                     params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(rngs,
                                         input_ids,
                                         attention_mask,
                                         position_ids,
                                         return_dict=False)["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params
Example #7
    def init_weights(self,
                     rng: jax.random.PRNGKey,
                     input_shape: Tuple,
                     params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
        attention_mask = jnp.ones_like(input_ids)
        head_mask = jnp.ones(
            (self.config.num_hidden_layers, self.config.num_attention_heads))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(rngs,
                                         input_ids,
                                         attention_mask,
                                         token_type_ids,
                                         head_mask,
                                         return_dict=False)["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params
Example #8
 def decay_mask_fn(params):
     flat_params = traverse_util.flatten_dict(params)
     flat_mask = {
         path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")])
         for path in flat_params
     }
     return traverse_util.unflatten_dict(flat_mask)
Example #9
def set_partitions(num_partitions, in_dict):
    rules = _get_partition_rules(num_partitions)
    replace = _replacement_rules(rules)
    initd = {k: _unmatched for k in flatten_dict(in_dict)}
    result = {k: replace(k, v) for k, v in initd.items()}
    assert _unmatched not in result.values(), 'Incomplete partition spec.'
    return unflatten_dict(result)
Example #10
 def decay_mask_fn(params):
     flat_params = traverse_util.flatten_dict(params)
     flat_mask = {
         path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale"))
         for path in flat_params
     }
     return traverse_util.unflatten_dict(flat_mask)
Example #11
    def init_weights(self,
                     rng: jax.random.PRNGKey,
                     input_shape: Tuple,
                     params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        pixel_values = jnp.zeros(input_shape, dtype=self.dtype)

        params_rng, dropout_rng = jax.random.split(rng)
        dropout_rng, droppath_rng = jax.random.split(dropout_rng)
        rngs = {
            "params": params_rng,
            "dropout": dropout_rng,
            "droppath": droppath_rng
        }

        random_params = self.module.init(rngs, pixel_values,
                                         return_dict=False)["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params
Example #12
def merge_nested_dicts(*ds):
    merged = {}
    for d in map(flatten_dict, map(flax.core.unfreeze, ds)):
        if any(k in merged.keys() for k in d.keys()):
            raise ValueError('Key conflict!')
        merged.update(d)
    return unflatten_dict(merged)
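A small illustrative call (hypothetical inputs) showing the expected behaviour:

a = {'encoder': {'kernel': 1}}
b = {'decoder': {'bias': 2}}
merge_nested_dicts(a, b)   # {'encoder': {'kernel': 1}, 'decoder': {'bias': 2}}
# merge_nested_dicts(a, a) raises ValueError('Key conflict!') because the flattened keys overlap.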
Example #13
 def decay_mask_fn(params):
     flat_params = traverse_util.flatten_dict(params)
     layer_norm_params = [
         (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
     ]
     flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
     return traverse_util.unflatten_dict(flat_mask)
Example #14
 def test_unflatten_dict(self):
     flat_xs = {
         ('foo', ): 1,
         ('bar', 'a'): 2,
     }
     xs = traverse_util.unflatten_dict(flat_xs)
     self.assertEqual(xs, {'foo': 1, 'bar': {'a': 2}})
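The same round trip also works with string keys via the `sep` argument of `flatten_dict`/`unflatten_dict`; a short sketch assuming `flax.traverse_util`:

from flax import traverse_util

xs = {'foo': 1, 'bar': {'a': 2}}
flat = traverse_util.flatten_dict(xs, sep='/')        # {'foo': 1, 'bar/a': 2}
assert traverse_util.unflatten_dict(flat, sep='/') == xs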
Example #15
    def _cast_floating_to(self,
                          params: Union[Dict, FrozenDict],
                          dtype: jnp.dtype,
                          mask: Any = None) -> Any:
        """
        Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
        """

        # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
        def conditional_cast(param):
            if isinstance(param, jnp.ndarray) and jnp.issubdtype(
                    param.dtype, jnp.floating):
                param = param.astype(dtype)
            return param

        if mask is None:
            return jax.tree_map(conditional_cast, params)

        flat_params = flatten_dict(params)
        flat_mask, _ = jax.tree_flatten(mask)

        for masked, key in zip(flat_mask, flat_params.keys()):
            if masked:
                param = flat_params[key]
                flat_params[key] = conditional_cast(param)

        return unflatten_dict(flat_params)
Example #16
def convert_state_dict_from_pt(model_class: ABC, state: Dict, config: PretrainedConfig):
    """
    Converts a PyTorch parameter state dict to an equivalent Flax parameter state dict
    """
    state = {k: v.numpy() for k, v in state.items()}
    state = model_class.convert_from_pytorch(state, config)
    state = unflatten_dict({tuple(k.split(".")): v for k, v in state.items()})
    return state
Example #17
def _map_over_modules_in_tree(fn, tree_or_leaf):
  """Helper for mapping function over submodules."""
  dict_or_leaf = serialization.to_state_dict(tree_or_leaf)
  if not isinstance(dict_or_leaf, dict) or dict_or_leaf == {}:
    return fn('', tree_or_leaf)
  else:
    flat_dict = traverse_util.flatten_dict(dict_or_leaf, keep_empty_nodes=True)
    mapped_flat_dict = {k: fn('_' + '_'.join(k), v)
                        for k, v in _sorted_items(flat_dict)}
    return serialization.from_state_dict(
        tree_or_leaf, traverse_util.unflatten_dict(mapped_flat_dict))
Example #18
def dict_replace(col, target, leaf_only=True):
    col_flat = flatten_dict(unfreeze(col))
    diff = {}
    for keys_flat in col_flat.keys():
        for tar_key, tar_val in target.items():
            if (keys_flat[-1] == tar_key if leaf_only else
                (tar_key in keys_flat)):
                diff[keys_flat] = tar_val
    col_flat.update(diff)
    col = unflatten_dict(col_flat)
    return col
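For illustration, a toy call (hypothetical values) that replaces every leaf named 'bias':

col = {'dense': {'kernel': 1, 'bias': 2}, 'out': {'bias': 3}}
dict_replace(col, {'bias': 0})   # {'dense': {'kernel': 1, 'bias': 0}, 'out': {'bias': 0}}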
Example #19
 def test_flatten_dict_keep_empty(self):
     xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
     flat_xs = traverse_util.flatten_dict(xs, keep_empty_nodes=True)
     self.assertEqual(
         flat_xs, {
             ('foo', ): 1,
             ('bar', 'a'): 2,
             ('bar', 'b'): traverse_util.empty_node,
         })
     xs_restore = traverse_util.unflatten_dict(flat_xs)
     self.assertEqual(xs, xs_restore)
Example #20
 def test_flatten_dict_is_leaf(self):
     xs = {'foo': {'c': 4}, 'bar': {'a': 2, 'b': {}}}
     flat_xs = traverse_util.flatten_dict(
         xs, is_leaf=lambda k, x: len(k) == 1 and len(x) == 2)
     self.assertEqual(flat_xs, {
         ('foo', 'c'): 4,
         ('bar', ): {
             'a': 2,
             'b': {}
         },
     })
     xs_restore = traverse_util.unflatten_dict(flat_xs)
     self.assertEqual(xs, xs_restore)
Example #21
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
    # convert pytorch tensor to numpy
    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

    random_flax_state_dict = flatten_dict(flax_model.params)
    flax_state_dict = {}

    remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
        flax_model.base_model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
    )
    add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params) and (
        flax_model.base_model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
    )

    # Need to change some parameters name to match Flax names so that we don't have to fork any layer
    for pt_key, pt_tensor in pt_state_dict.items():

        pt_tuple_key = tuple(pt_key.split("."))

        has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
        require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict

        if remove_base_model_prefix and has_base_model_prefix:
            pt_tuple_key = pt_tuple_key[1:]
        elif add_base_model_prefix and require_base_model_prefix:
            pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key

        # Correctly rename weight parameters
        if pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
            pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
        if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
            pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
        elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
            pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
            pt_tensor = pt_tensor.T
        elif pt_tuple_key[-1] == "gamma":
            pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
        elif pt_tuple_key[-1] == "beta":
            pt_tuple_key = pt_tuple_key[:-1] + ("bias",)

        if pt_tuple_key in random_flax_state_dict:
            if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape:
                raise ValueError(
                    "PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape {random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
                )

        # also add unexpected weight so that warning is thrown
        flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor)

    return unflatten_dict(flax_state_dict)
Example #22
 def decay_mask_fn(params):
     flat_params = traverse_util.flatten_dict(params)
     # find out all LayerNorm parameters
     layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
     layer_norm_named_params = set(
         [
             layer[-2:]
             for layer_norm_name in layer_norm_candidates
             for layer in flat_params.keys()
             if layer_norm_name in "".join(layer).lower()
         ]
     )
     flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
     return traverse_util.unflatten_dict(flat_mask)
Example #23
def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
    import torch

    # Load the index
    flax_state_dict = {}
    for shard_file in shard_filenames:
        # load the PyTorch shard with torch.load
        pt_state_dict = torch.load(shard_file)
        pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

        model_prefix = flax_model.base_model_prefix
        random_flax_state_dict = flatten_dict(flax_model.params)

        load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
            model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
        )
        load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
            model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
        )
        # Need to change some parameters name to match Flax names
        for pt_key, pt_tensor in pt_state_dict.items():

            pt_tuple_key = tuple(pt_key.split("."))

            # remove base model prefix if necessary
            has_base_model_prefix = pt_tuple_key[0] == model_prefix
            if load_model_with_head_into_base_model and has_base_model_prefix:
                pt_tuple_key = pt_tuple_key[1:]

            # Correctly rename weight parameters
            flax_key, flax_tensor = rename_key_and_reshape_tensor(
                pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix
            )
            # add model prefix if necessary
            require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict
            if load_base_model_into_model_with_head and require_base_model_prefix:
                flax_key = (model_prefix,) + flax_key

            if flax_key in random_flax_state_dict:
                if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
                    raise ValueError(
                        f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                        f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
                    )

            # also add unexpected weight so that warning is thrown
            flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
    return unflatten_dict(flax_state_dict)
Example #24
    def init_weights(self,
                     rng: jax.random.PRNGKey,
                     input_shape: Tuple,
                     params: FrozenDict = None) -> FrozenDict:
        encoder_input_shape, decoder_input_shape = input_shape

        # init input tensors
        input_ids = jnp.zeros(encoder_input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)

        batch_size, sequence_length = input_ids.shape
        position_ids = jnp.broadcast_to(
            jnp.arange(sequence_length)[None, :],
            (batch_size, sequence_length))

        decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
        if not decoder_batch_size == batch_size:
            raise ValueError(
                f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder and {decoder_batch_size} for decoder."
            )
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(decoder_sequence_length)[None, :],
            (decoder_batch_size, decoder_sequence_length))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
            position_ids,
            decoder_position_ids,
        )["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params
Example #25
def update_optimizer(optimizer, t5_data):
    """Update flax optimizer for T5 model."""
    optimizer_data = traverse_util.flatten_dict(optimizer.state_dict())
    optimizer_data = {'/'.join(k): v for k, v in optimizer_data.items()}
    # Shape check.
    for k, v in jax.tree_map(np.shape, t5_data).items():
        if np.shape(optimizer_data[k]) != v:
            raise ValueError(
                f'Variable {k} has shape {v} != {np.shape(optimizer_data[k])}')
    # Dtype check template optimizer against imported T5 arrays.
    for k, v in jax.tree_map(lambda x: x.dtype, t5_data).items():
        if optimizer_data[k].dtype != v:
            raise ValueError(
                f'Variable {k} has dtype {v} != {optimizer_data[k].dtype}')
    optimizer_data = t5_data
    optimizer_data = traverse_util.unflatten_dict(
        {tuple(k.split('/')): v
         for k, v in optimizer_data.items()})
    return optimizer.restore_state(optimizer_data)
Example #26
def update_dense_symm(params, names=["dense_symm", "Dense"]):
    """Updates DenseSymm kernels in pre-PR#1030 parameter pytrees to the new
    3D convention.

    Args:
        params: a parameter pytree
        names: layer names to search for; default: those used in RBMSymm and GCNN*
    """
    params = unfreeze(params)  # just in case, doesn't break with a plain dict

    def fix_one_kernel(args):
        path, array = args
        if (len(path) > 1 and path[-2] in names and path[-1] == "kernel"
                and array.ndim == 2):
            array = jnp.expand_dims(array, 1)
        return (path, array)

    return unflatten_dict(
        dict(map(fix_one_kernel,
                 flatten_dict(params).items())))
Example #27
 def __call__(self, features: Features) -> Features:
     features = traverse_util.flatten_dict(features)
     for name in list(features):
         dtype = features[name].dtype
         if dtype not in self.types:
             del features[name]
             logging.warning(
                 "Removing feature %r because dtype %s is not supported in JAX.",
                 name, dtype)
         elif isinstance(features[name], tf.SparseTensor):
             del features[name]
             logging.warning(
                 "Removing feature %r because sparse tensors are not "
                 "supported in JAX.", name)
         elif isinstance(features[name], tf.RaggedTensor):
             del features[name]
             logging.warning(
                 "Removing feature %r because ragged tensors are not support in "
                 "JAX.", name)
     features = traverse_util.unflatten_dict(features)
     return features  # pytype: disable=bad-return-type
Example #28
  def restore_state(self, state_dict):
    """Restore parameter and optimizer state from state dictionary.

    Adapted from
    https://github.com/google-research/t5x/blob/main/t5x/optimizers.py. Includes
    support to handle `optax.EmptyState`.

    Args:
      state_dict: Contains desired new parameters and optimizer state

    Returns:
      Updated train state.
    """
    params = serialization.from_state_dict(self.params, state_dict["params"])

    # Get all the possible keys in the reference optimizer state.
    flat_ref_opt_state_dict = traverse_util.flatten_dict(
        serialization.to_state_dict(self.opt_state),
        keep_empty_nodes=True,
        sep="/")

    flat_src_opt_state_dict = dict(
        traverse_util.flatten_dict(state_dict["opt_state"], sep="/"))
    # Adding the empty paths back to flat_src_opt_state_dict.
    for k, v in flat_ref_opt_state_dict.items():
      if k in flat_src_opt_state_dict:
        continue
      # The key is not in the input state dict, presumably because it
      # corresponds to an empty dict.
      if v != traverse_util.empty_node:
        raise ValueError(
            f"Failed to restore optimizer state, path {k} is not present "
            "in the input optimizer state dict.")
      flat_src_opt_state_dict[k] = v

    # Restore state from the enhanced state dict.
    opt_state = serialization.from_state_dict(
        self.opt_state,
        traverse_util.unflatten_dict(flat_src_opt_state_dict, sep="/"))
    return self.replace(params=params, opt_state=opt_state)
Example #29
    def from_pretrained(cls,
                        pretrained_model_name_or_path: Union[str, os.PathLike],
                        dtype: jnp.dtype = jnp.float32,
                        *model_args,
                        **kwargs):
        r"""
        Instantiate a pretrained flax model from a pre-trained model configuration.

        The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                Can be either:

                    - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
                      a user or organization name, like ``dbmdz/bert-base-german-cased``.
                    - A path to a `directory` containing model weights saved using
                      :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                    - A path or url to a `pt index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In this
                      case, ``from_pt`` should be set to :obj:`True`.
            model_args (sequence of positional arguments, `optional`):
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
            config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`):
                Can be either:

                    - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
                    - a string or path valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.

                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
                be automatically loaded when:

                    - The model is a model provided by the library (loaded with the `model id` string of a pretrained
                      model).
                    - The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
                      by supplying the save directory.
                    - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
                      configuration JSON file named `config.json` is found in the directory.
            cache_dir (:obj:`Union[str, os.PathLike]`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Load the model weights from a PyTorch checkpoint save file (see docstring of
                ``pretrained_model_name_or_path`` argument).
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
                identifier allowed by git.
            kwargs (remaining dictionary of keyword arguments, `optional`):
                Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g.,
                :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
                automatically loaded:

                    - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
                      underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
                      already been done)
                    - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
                      initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
                      ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
                      with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
                      attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            >>> from transformers import BertConfig, FlaxBertModel
            >>> # Download model and configuration from huggingface.co and cache.
            >>> model = FlaxBertModel.from_pretrained('bert-base-cased')
            >>> # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
            >>> model = FlaxBertModel.from_pretrained('./test/saved_model/')
            >>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower, for example purposes, not runnable).
            >>> config = BertConfig.from_json_file('./pt_model/config.json')
            >>> model = FlaxBertModel.from_pretrained('./pt_model/pytorch_model.bin', from_pt=True, config=config)
        """
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)

        user_agent = {
            "file_type": "model",
            "framework": "flax",
            "from_auto_class": from_auto_class
        }
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        if is_offline_mode() and not local_files_only:
            logger.info("Offline mode: forcing local_files_only=True")
            local_files_only = True

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                _from_auto=from_auto_class,
                _from_pipeline=from_pipeline,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Add the dtype to model_kwargs
        model_kwargs["dtype"] = dtype

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                if from_pt and os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                WEIGHTS_NAME)
                elif os.path.isfile(
                        os.path.join(pretrained_model_name_or_path,
                                     FLAX_WEIGHTS_NAME)):
                    # Load from a Flax checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path,
                                                FLAX_WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        f"Error no file named {[FLAX_WEIGHTS_NAME, WEIGHTS_NAME]} found in directory "
                        f"{pretrained_model_name_or_path} or `from_pt` set to False"
                    )
            elif os.path.isfile(
                    pretrained_model_name_or_path) or is_remote_url(
                        pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME,
                    revision=revision,
                )

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    user_agent=user_agent,
                )
            except EnvironmentError as err:
                logger.error(err)
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
                )
                raise EnvironmentError(msg)

            if resolved_archive_file == archive_file:
                logger.info(f"loading weights file {archive_file}")
            else:
                logger.info(
                    f"loading weights file {archive_file} from cache at {resolved_archive_file}"
                )
        else:
            resolved_archive_file = None

        # init random models
        model = cls(config, *model_args, **model_kwargs)

        if from_pt:
            state = load_pytorch_checkpoint_in_flax_state_dict(
                model, resolved_archive_file)
        else:
            with open(resolved_archive_file, "rb") as state_f:
                try:
                    state = from_bytes(cls, state_f.read())
                except UnpicklingError:
                    raise EnvironmentError(
                        f"Unable to convert {archive_file} to Flax deserializable object. "
                    )

        # if model is base model only use model_prefix key
        if cls.base_model_prefix not in dict(
                model.params) and cls.base_model_prefix in state:
            state = state[cls.base_model_prefix]

        # flatten dicts
        state = flatten_dict(state)
        random_state = flatten_dict(unfreeze(model.params))

        missing_keys = model.required_params - set(state.keys())
        unexpected_keys = set(state.keys()) - model.required_params

        # add missing keys as random parameters
        for missing_key in missing_keys:
            state[missing_key] = random_state[missing_key]

        # remove unexpected keys to not be saved again
        for unexpected_key in unexpected_keys:
            del state[unexpected_key]

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logger.info(
                f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
            )

        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized: {missing_keys}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        else:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                f"If your task is similar to the task the model of the checkpoint was trained on, "
                f"you can already use {model.__class__.__name__} for predictions without further training."
            )

        # set correct parameters
        model.params = unflatten_dict(state)
        return model
Example #30
 def mask(data):
     flat = traverse_util.flatten_dict(data)
     return traverse_util.unflatten_dict(
         {k: fn(k, v)
          for k, v in flat.items()})
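This closure presumably comes from a factory that captures `fn`; a self-contained sketch of the same pattern (the `make_mask` name and the bias predicate are assumptions for illustration):

from flax import traverse_util

def make_mask(fn):
    def mask(data):
        flat = traverse_util.flatten_dict(data)
        return traverse_util.unflatten_dict({k: fn(k, v) for k, v in flat.items()})
    return mask

# e.g. a boolean pytree that is True for every leaf whose last path element is 'bias'
bias_mask = make_mask(lambda path, value: path[-1] == 'bias')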