def load_pretrained(variables, url='', default_cfg=None, filter_fn=None):
    if not url:
        assert default_cfg is not None and default_cfg['url']
        url = default_cfg['url']
    state_dict = load_state_dict_from_url(url, transpose=True)

    source_params, source_state = split_state_dict(state_dict)
    if filter_fn is not None:
        # filter after split as we may have modified the split criteria (ie bn running vars)
        source_params = filter_fn(source_params)
        source_state = filter_fn(source_state)

    # FIXME better way to do this?
    var_unfrozen = unfreeze(variables)
    missing_keys = []
    flat_params = flatten_dict(var_unfrozen['params'])
    flat_param_keys = set()
    for k, v in flat_params.items():
        flat_k = '.'.join(k)
        if flat_k in source_params:
            # check the pretrained tensor against the model's expected shape
            assert source_params[flat_k].shape == v.shape
            flat_params[k] = source_params[flat_k]
        else:
            missing_keys.append(flat_k)
        flat_param_keys.add(flat_k)
    unexpected_keys = list(set(source_params.keys()).difference(flat_param_keys))
    params = freeze(unflatten_dict(flat_params))

    flat_state = flatten_dict(var_unfrozen['batch_stats'])
    flat_state_keys = set()
    for k, v in flat_state.items():
        flat_k = '.'.join(k)
        if flat_k in source_state:
            # check the pretrained tensor against the model's expected shape
            assert source_state[flat_k].shape == v.shape
            flat_state[k] = source_state[flat_k]
        else:
            missing_keys.append(flat_k)
        flat_state_keys.add(flat_k)
    unexpected_keys.extend(list(set(source_state.keys()).difference(flat_state_keys)))
    batch_stats = freeze(unflatten_dict(flat_state))

    if missing_keys:
        print(f' WARNING: {len(missing_keys)} keys missing while loading state_dict. {str(missing_keys)}')
    if unexpected_keys:
        print(f' WARNING: {len(unexpected_keys)} unexpected keys found while loading state_dict. {str(unexpected_keys)}')

    return dict(params=params, batch_stats=batch_stats)
def weight_decay_mask(params):
    params = traverse_util.flatten_dict(params)
    # mask on the key path, not the array value: no decay for biases and LayerNorm scales
    mask = {
        k: (k[-1] != "bias" and k[-2:] != ("LayerNorm", "scale"))
        for k in params.keys()
    }
    return traverse_util.unflatten_dict(mask)
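# Usage sketch for the mask above (my own example, not from the source):
# optax.adamw accepts a callable `mask`, so biases and LayerNorm scales can be
# excluded from weight decay.
import jax.numpy as jnp
import optax

params = {
    "dense": {"kernel": jnp.ones((8, 4)), "bias": jnp.zeros((4,))},
    "LayerNorm": {"scale": jnp.ones((4,)), "bias": jnp.zeros((4,))},
}
tx = optax.adamw(learning_rate=1e-3, weight_decay=1e-2, mask=weight_decay_mask)
opt_state = tx.init(params)  # decay will apply only to ('dense', 'kernel')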
def update_GCNN_parity(params):
    """Adds biases of parity-flip layers to the corresponding no-flip layers.
    Corrects for changes in GCNN_parity due to PR #1030 in NetKet 3.3.

    Args:
        params: a parameter pytree
    """
    # unfreeze just in case, doesn't break with a plain dict
    params = flatten_dict(unfreeze(params))
    to_remove = []
    for path in params:
        if (
            len(path) > 1
            and path[-2].startswith("equivariant_layers_flip")
            and path[-1] == "bias"
        ):
            alt_path = (
                *path[:-2],
                path[-2].replace("equivariant_layers_flip", "equivariant_layers"),
                path[-1],
            )
            params[alt_path] = params[alt_path] + params[path]
            to_remove.append(path)
    for path in to_remove:
        del params[path]
    return unflatten_dict(params)
def test_to_fp32(self):
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        model = model_class(config)

        # cast all params to fp16 and back to fp32
        params = model.to_fp16(model.params)
        params = model.to_fp32(params)

        # test if all params are in fp32
        types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
        for name, type_ in types.items():
            self.assertEqual(type_, jnp.float32, msg=f"param {name} is not in fp32.")

        # test masking
        flat_params = flatten_dict(params)
        key = random.choice(list(flat_params.keys()))  # choose a random param
        mask = {path: path != key for path in flat_params}  # don't cast the key
        mask = unflatten_dict(mask)

        # cast to fp16 and back to fp32 with mask
        params = model.to_fp16(model.params)
        params = model.to_fp32(params, mask)

        # test if all params are in fp32 except key
        types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
        for name, type_ in types.items():
            if name == key:
                self.assertEqual(type_, jnp.float16, msg=f"param {name} should be in fp16.")
            else:
                self.assertEqual(type_, jnp.float32, msg=f"param {name} is not in fp32.")
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
    # init input tensors
    input_ids = jnp.zeros(input_shape, dtype="i4")
    attention_mask = jnp.ones_like(input_ids)
    batch_size, sequence_length = input_ids.shape
    position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}

    module_init_outputs = self.module.init(
        rngs,
        input_ids,
        attention_mask,
        position_ids,
        return_dict=False,
    )
    random_params = module_init_outputs["params"]

    if params is not None:
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))
    else:
        return random_params
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
    # init input tensors
    input_ids = jnp.zeros(input_shape, dtype="i4")
    attention_mask = jnp.ones_like(input_ids)
    position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)

    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}

    random_params = self.module.init(
        rngs, input_ids, attention_mask, position_ids, return_dict=False
    )["params"]

    if params is not None:
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))
    else:
        return random_params
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
    # init input tensors
    input_ids = jnp.zeros(input_shape, dtype="i4")
    token_type_ids = jnp.zeros_like(input_ids)
    attention_mask = jnp.ones_like(input_ids)
    head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))

    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}

    random_params = self.module.init(
        rngs, input_ids, attention_mask, token_type_ids, head_mask, return_dict=False
    )["params"]

    if params is not None:
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))
    else:
        return random_params
def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    flat_mask = {
        path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")])
        for path in flat_params
    }
    return traverse_util.unflatten_dict(flat_mask)
def set_partitions(num_partitions, in_dict):
    rules = _get_partition_rules(num_partitions)
    replace = _replacement_rules(rules)
    initd = {k: _unmatched for k in flatten_dict(in_dict)}
    result = {k: replace(k, v) for k, v in initd.items()}
    assert _unmatched not in result.values(), 'Incomplete partition spec.'
    return unflatten_dict(result)
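# The helpers `_get_partition_rules`, `_replacement_rules`, and `_unmatched`
# are defined elsewhere in the source module. A minimal sketch of the usual
# regex-rule pattern, with hypothetical rules (not the source's actual ones):
import re

_unmatched = object()  # sentinel for parameters that no rule matched

def _get_partition_rules(num_partitions):
    # each rule: (regex over the '/'-joined parameter path, partition spec)
    return [
        (r"kernel$", (None, num_partitions)),  # shard kernels over the last axis
        (r"bias$", None),                      # replicate biases
    ]

def _replacement_rules(rules):
    def replace(key, val):
        path = "/".join(key)
        for rule, spec in rules:
            if re.search(rule, path):
                return spec
        return val  # leaves the `_unmatched` sentinel for the assert above
    return replace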
def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    flat_mask = {
        path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale"))
        for path in flat_params
    }
    return traverse_util.unflatten_dict(flat_mask)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
    # init input tensors
    pixel_values = jnp.zeros(input_shape, dtype=self.dtype)

    params_rng, dropout_rng = jax.random.split(rng)
    dropout_rng, droppath_rng = jax.random.split(dropout_rng)
    rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng}

    random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"]

    if params is not None:
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))
    else:
        return random_params
def merge_nested_dicts(*ds):
    merged = {}
    for d in map(flatten_dict, map(flax.core.unfreeze, ds)):
        if any(k in merged.keys() for k in d.keys()):
            raise ValueError('Key conflict!')
        merged.update(d)
    return unflatten_dict(merged)
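# Quick usage example (constructed here, not from the source): merging two
# parameter trees with disjoint leaf paths; any overlapping path raises the
# ValueError('Key conflict!') above.
encoder = {"encoder": {"kernel": 1}}
head = {"head": {"kernel": 2}}
assert merge_nested_dicts(encoder, head) == {
    "encoder": {"kernel": 1},
    "head": {"kernel": 2},
}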
def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    layer_norm_params = [
        (name, "scale")
        for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
    ]
    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)
def test_unflatten_dict(self):
    flat_xs = {
        ('foo',): 1,
        ('bar', 'a'): 2,
    }
    xs = traverse_util.unflatten_dict(flat_xs)
    self.assertEqual(xs, {'foo': 1, 'bar': {'a': 2}})
def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
    """
    Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
    """

    # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
    def conditional_cast(param):
        if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
            param = param.astype(dtype)
        return param

    if mask is None:
        return jax.tree_map(conditional_cast, params)

    flat_params = flatten_dict(params)
    flat_mask, _ = jax.tree_flatten(mask)

    for masked, key in zip(flat_mask, flat_params.keys()):
        if masked:
            param = flat_params[key]
            flat_params[key] = conditional_cast(param)

    return unflatten_dict(flat_params)
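# Usage sketch (assumed, not from the source): a boolean mask tree with the
# same structure as `params` selects which leaves get cast; here everything
# except biases would be converted to fp16.
import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict

params = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros((2,))}}
mask = unflatten_dict({path: path[-1] != "bias" for path in flatten_dict(params)})
# model._cast_floating_to(params, jnp.float16, mask)  # kernel -> fp16, bias unchanged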
def convert_state_dict_from_pt(model_class: ABC, state: Dict, config: PretrainedConfig):
    """
    Converts a PyTorch parameter state dict to an equivalent Flax parameter state dict
    """
    state = {k: v.numpy() for k, v in state.items()}
    state = model_class.convert_from_pytorch(state, config)
    state = unflatten_dict({tuple(k.split(".")): v for k, v in state.items()})
    return state
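# The tuple-split idiom above is what turns PyTorch's '.'-separated parameter
# names into the nested dict layout Flax expects; a minimal illustration with
# made-up data:
from flax.traverse_util import unflatten_dict

state = {"encoder.layer.0.kernel": 1}
assert unflatten_dict({tuple(k.split(".")): v for k, v in state.items()}) == {
    "encoder": {"layer": {"0": {"kernel": 1}}}
}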
def _map_over_modules_in_tree(fn, tree_or_leaf):
    """Helper for mapping function over submodules."""
    dict_or_leaf = serialization.to_state_dict(tree_or_leaf)
    if not isinstance(dict_or_leaf, dict) or dict_or_leaf == {}:
        return fn('', tree_or_leaf)
    else:
        flat_dict = traverse_util.flatten_dict(dict_or_leaf, keep_empty_nodes=True)
        mapped_flat_dict = {k: fn('_' + '_'.join(k), v)
                            for k, v in _sorted_items(flat_dict)}
        return serialization.from_state_dict(
            tree_or_leaf, traverse_util.unflatten_dict(mapped_flat_dict))
def dict_replace(col, target, leaf_only=True):
    col_flat = flatten_dict(unfreeze(col))
    diff = {}
    for keys_flat in col_flat.keys():
        for tar_key, tar_val in target.items():
            if (keys_flat[-1] == tar_key if leaf_only else (tar_key in keys_flat)):
                diff[keys_flat] = tar_val
    col_flat.update(diff)
    col = unflatten_dict(col_flat)
    return col
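# Usage example (constructed, not from the source): overwrite every leaf
# named 'bias' anywhere in the tree; with leaf_only=False, any path merely
# containing the key (not just ending in it) would match instead.
import jax.numpy as jnp

col = {
    "layer1": {"kernel": jnp.ones((2, 2)), "bias": jnp.ones((2,))},
    "layer2": {"bias": jnp.ones((2,))},
}
col = dict_replace(col, {"bias": jnp.zeros((2,))})  # both biases become zeros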
def test_flatten_dict_keep_empty(self):
    xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
    flat_xs = traverse_util.flatten_dict(xs, keep_empty_nodes=True)
    self.assertEqual(flat_xs, {
        ('foo',): 1,
        ('bar', 'a'): 2,
        ('bar', 'b'): traverse_util.empty_node,
    })
    xs_restore = traverse_util.unflatten_dict(flat_xs)
    self.assertEqual(xs, xs_restore)
def test_flatten_dict_is_leaf(self):
    xs = {'foo': {'c': 4}, 'bar': {'a': 2, 'b': {}}}
    flat_xs = traverse_util.flatten_dict(
        xs, is_leaf=lambda k, x: len(k) == 1 and len(x) == 2)
    self.assertEqual(flat_xs, {
        ('foo', 'c'): 4,
        ('bar',): {'a': 2, 'b': {}},
    })
    xs_restore = traverse_util.unflatten_dict(flat_xs)
    self.assertEqual(xs, xs_restore)
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
    # convert pytorch tensor to numpy
    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

    random_flax_state_dict = flatten_dict(flax_model.params)
    flax_state_dict = {}

    remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
        flax_model.base_model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
    )
    add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params) and (
        flax_model.base_model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
    )

    # Need to change some parameters name to match Flax names so that we don't have to fork any layer
    for pt_key, pt_tensor in pt_state_dict.items():
        pt_tuple_key = tuple(pt_key.split("."))

        has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
        require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict

        if remove_base_model_prefix and has_base_model_prefix:
            pt_tuple_key = pt_tuple_key[1:]
        elif add_base_model_prefix and require_base_model_prefix:
            pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key

        # Correctly rename weight parameters
        if pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
            pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
        if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
            pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
        elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
            pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
            pt_tensor = pt_tensor.T
        elif pt_tuple_key[-1] == "gamma":
            pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
        elif pt_tuple_key[-1] == "beta":
            pt_tuple_key = pt_tuple_key[:-1] + ("bias",)

        if pt_tuple_key in random_flax_state_dict:
            if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape:
                raise ValueError(
                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                    f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
                )

        # also add unexpected weight so that warning is thrown
        flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor)

    return unflatten_dict(flax_state_dict)
def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    # find out all LayerNorm parameters
    layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
    layer_norm_named_params = {
        layer[-2:]
        for layer_norm_name in layer_norm_candidates
        for layer in flat_params.keys()
        if layer_norm_name in "".join(layer).lower()
    }
    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)
def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
    import torch

    # Load the index
    flax_state_dict = {}
    for shard_file in shard_filenames:
        # load the shard with torch and convert the tensors to numpy
        pt_state_dict = torch.load(shard_file)
        pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

        model_prefix = flax_model.base_model_prefix
        random_flax_state_dict = flatten_dict(flax_model.params)

        load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
            model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
        )
        load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
            model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
        )

        # Need to change some parameters name to match Flax names
        for pt_key, pt_tensor in pt_state_dict.items():
            pt_tuple_key = tuple(pt_key.split("."))

            # remove base model prefix if necessary
            has_base_model_prefix = pt_tuple_key[0] == model_prefix
            if load_model_with_head_into_base_model and has_base_model_prefix:
                pt_tuple_key = pt_tuple_key[1:]

            # Correctly rename weight parameters
            flax_key, flax_tensor = rename_key_and_reshape_tensor(
                pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix
            )
            # add model prefix if necessary
            require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict
            if load_base_model_into_model_with_head and require_base_model_prefix:
                flax_key = (model_prefix,) + flax_key

            if flax_key in random_flax_state_dict:
                if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
                    raise ValueError(
                        f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                        f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
                    )

            # also add unexpected weight so that warning is thrown
            flax_state_dict[flax_key] = jnp.asarray(flax_tensor)

    return unflatten_dict(flax_state_dict)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
    encoder_input_shape, decoder_input_shape = input_shape

    # init input tensors
    input_ids = jnp.zeros(encoder_input_shape, dtype="i4")
    attention_mask = jnp.ones_like(input_ids)
    decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
    decoder_attention_mask = jnp.ones_like(decoder_input_ids)

    batch_size, sequence_length = input_ids.shape
    position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

    decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
    if decoder_batch_size != batch_size:
        raise ValueError(
            f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} "
            f"for encoder and {decoder_batch_size} for decoder."
        )
    decoder_position_ids = jnp.broadcast_to(
        jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length)
    )

    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}

    random_params = self.module.init(
        rngs,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
    )["params"]

    if params is not None:
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))
    else:
        return random_params
def update_optimizer(optimizer, t5_data):
    """Update flax optimizer for T5 model."""
    optimizer_data = traverse_util.flatten_dict(optimizer.state_dict())
    optimizer_data = {'/'.join(k): v for k, v in optimizer_data.items()}

    # Shape check.
    for k, v in jax.tree_map(np.shape, t5_data).items():
        if np.shape(optimizer_data[k]) != v:
            raise ValueError(
                f'Variable {k} has shape {v} != {np.shape(optimizer_data[k])}')

    # Dtype check template optimizer against imported T5 arrays.
    for k, v in jax.tree_map(lambda x: x.dtype, t5_data).items():
        if optimizer_data[k].dtype != v:
            raise ValueError(
                f'Variable {k} has dtype {v} != {optimizer_data[k].dtype}')

    optimizer_data = t5_data
    optimizer_data = traverse_util.unflatten_dict(
        {tuple(k.split('/')): v for k, v in optimizer_data.items()})
    return optimizer.restore_state(optimizer_data)
def update_dense_symm(params, names=["dense_symm", "Dense"]):
    """Updates DenseSymm kernels in pre-PR#1030 parameter pytrees to the new 3D convention.

    Args:
        params: a parameter pytree
        names: layer names to search for, default: those used in RBMSymm and GCNN*
    """
    params = unfreeze(params)  # just in case, doesn't break with a plain dict

    def fix_one_kernel(args):
        path, array = args
        if (len(path) > 1 and path[-2] in names
                and path[-1] == "kernel" and array.ndim == 2):
            array = jnp.expand_dims(array, 1)
        return (path, array)

    return unflatten_dict(dict(map(fix_one_kernel, flatten_dict(params).items())))
def __call__(self, features: Features) -> Features:
    features = traverse_util.flatten_dict(features)
    for name in list(features):
        dtype = features[name].dtype
        if dtype not in self.types:
            del features[name]
            logging.warning(
                "Removing feature %r because dtype %s is not supported in JAX.",
                name, dtype)
        elif isinstance(features[name], tf.SparseTensor):
            del features[name]
            logging.warning(
                "Removing feature %r because sparse tensors are not "
                "supported in JAX.", name)
        elif isinstance(features[name], tf.RaggedTensor):
            del features[name]
            logging.warning(
                "Removing feature %r because ragged tensors are not "
                "supported in JAX.", name)
    features = traverse_util.unflatten_dict(features)
    return features  # pytype: disable=bad-return-type
def restore_state(self, state_dict):
    """Restore parameter and optimizer state from state dictionary.

    Adapted from https://github.com/google-research/t5x/blob/main/t5x/optimizers.py.
    Includes support to handle `optax.EmptyState`.

    Args:
        state_dict: Contains desired new parameters and optimizer state

    Returns:
        Updated train state.
    """
    params = serialization.from_state_dict(self.params, state_dict["params"])

    # Get all the possible keys in the reference optimizer state.
    flat_ref_opt_state_dict = traverse_util.flatten_dict(
        serialization.to_state_dict(self.opt_state), keep_empty_nodes=True, sep="/")

    flat_src_opt_state_dict = dict(
        traverse_util.flatten_dict(state_dict["opt_state"], sep="/"))

    # Adding the empty paths back to flat_src_opt_state_dict.
    for k, v in flat_ref_opt_state_dict.items():
        if k in flat_src_opt_state_dict:
            continue
        # The key is not in the input state dict, presumably because it
        # corresponds to an empty dict.
        if v != traverse_util.empty_node:
            raise ValueError(
                f"Failed to restore optimizer state, path {k} is not present "
                "in the input optimizer state dict.")
        flat_src_opt_state_dict[k] = v

    # Restore state from the enhanced state dict.
    opt_state = serialization.from_state_dict(
        self.opt_state, traverse_util.unflatten_dict(flat_src_opt_state_dict, sep="/"))

    return self.replace(params=params, opt_state=opt_state)
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], dtype: jnp.dtype = jnp.float32, *model_args, **kwargs):
    r"""
    Instantiate a pretrained flax model from a pre-trained model configuration.

    The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
    pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
    task.

    The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
    weights are discarded.

    Parameters:
        pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
            Can be either:

                - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
                  Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced
                  under a user or organization name, like ``dbmdz/bert-base-german-cased``.
                - A path to a `directory` containing model weights saved using
                  :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                - A path or url to a `pt index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In this
                  case, ``from_pt`` should be set to :obj:`True`.
        model_args (sequence of positional arguments, `optional`):
            All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
        config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`):
            Can be either:

                - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
                - a string or path valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.

            Configuration for the model to use instead of an automatically loaded configuration. Configuration can
            be automatically loaded when:

                - The model is a model provided by the library (loaded with the `model id` string of a pretrained
                  model).
                - The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
                  by supplying the save directory.
                - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
                  configuration JSON file named `config.json` is found in the directory.
        cache_dir (:obj:`Union[str, os.PathLike]`, `optional`):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the
            standard cache should not be used.
        from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Load the model weights from a PyTorch checkpoint save file (see docstring of
            ``pretrained_model_name_or_path`` argument).
        force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.
        resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to delete incompletely received files. Will attempt to resume the download if such a
            file exists.
        proxies (:obj:`Dict[str, str]`, `optional`):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to only look at local files (i.e., do not try to download the model).
        revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
            identifier allowed by git.
        kwargs (remaining dictionary of keyword arguments, `optional`):
            Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
            :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
            automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
                  underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
                  already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
                  initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
                  ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
                  with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
                  attribute will be passed to the underlying model's ``__init__`` function.

    Examples::

        >>> from transformers import BertConfig, FlaxBertModel
        >>> # Download model and configuration from huggingface.co and cache.
        >>> model = FlaxBertModel.from_pretrained('bert-base-cased')
        >>> # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
        >>> model = FlaxBertModel.from_pretrained('./test/saved_model/')
        >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
        >>> config = BertConfig.from_json_file('./pt_model/config.json')
        >>> model = FlaxBertModel.from_pretrained('./pt_model/pytorch_model.bin', from_pt=True, config=config)
    """
    config = kwargs.pop("config", None)
    cache_dir = kwargs.pop("cache_dir", None)
    from_pt = kwargs.pop("from_pt", False)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", False)
    use_auth_token = kwargs.pop("use_auth_token", None)
    revision = kwargs.pop("revision", None)
    from_pipeline = kwargs.pop("_from_pipeline", None)
    from_auto_class = kwargs.pop("_from_auto", False)

    user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
    if from_pipeline is not None:
        user_agent["using_pipeline"] = from_pipeline

    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    # Load config if we don't provide a configuration
    if not isinstance(config, PretrainedConfig):
        config_path = config if config is not None else pretrained_model_name_or_path
        config, model_kwargs = cls.config_class.from_pretrained(
            config_path,
            *model_args,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            revision=revision,
            _from_auto=from_auto_class,
            _from_pipeline=from_pipeline,
            **kwargs,
        )
    else:
        model_kwargs = kwargs

    # Add the dtype to model_kwargs
    model_kwargs["dtype"] = dtype

    # Load model
    if pretrained_model_name_or_path is not None:
        if os.path.isdir(pretrained_model_name_or_path):
            if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                # Load from a PyTorch checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
                # Load from a Flax checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
            else:
                raise EnvironmentError(
                    f"Error no file named {[FLAX_WEIGHTS_NAME, WEIGHTS_NAME]} found in directory "
                    f"{pretrained_model_name_or_path} or `from_pt` set to False"
                )
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            archive_file = pretrained_model_name_or_path
        else:
            archive_file = hf_bucket_url(
                pretrained_model_name_or_path,
                filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME,
                revision=revision,
            )

        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(
                archive_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
            )
        except EnvironmentError as err:
            logger.error(err)
            msg = (
                f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
            )
            raise EnvironmentError(msg)

        if resolved_archive_file == archive_file:
            logger.info(f"loading weights file {archive_file}")
        else:
            logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
    else:
        resolved_archive_file = None

    # init random models
    model = cls(config, *model_args, **model_kwargs)

    if from_pt:
        state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
    else:
        with open(resolved_archive_file, "rb") as state_f:
            try:
                state = from_bytes(cls, state_f.read())
            except UnpicklingError:
                raise EnvironmentError(
                    f"Unable to convert {archive_file} to Flax deserializable object. "
                )

    # if model is base model only use model_prefix key
    if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
        state = state[cls.base_model_prefix]

    # flatten dicts
    state = flatten_dict(state)
    random_state = flatten_dict(unfreeze(model.params))

    missing_keys = model.required_params - set(state.keys())
    unexpected_keys = set(state.keys()) - model.required_params

    # add missing keys as random parameters
    for missing_key in missing_keys:
        state[missing_key] = random_state[missing_key]

    # remove unexpected keys to not be saved again
    for unexpected_key in unexpected_keys:
        del state[unexpected_key]

    if len(unexpected_keys) > 0:
        logger.warning(
            f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
            f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
            f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
            f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
            f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
            f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
        )
    else:
        logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")

    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
            f"and are newly initialized: {missing_keys}\n"
            f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )
    else:
        logger.info(
            f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
            f"If your task is similar to the task the model of the checkpoint was trained on, "
            f"you can already use {model.__class__.__name__} for predictions without further training."
        )

    # set correct parameters
    model.params = unflatten_dict(state)

    return model
def mask(data):
    flat = traverse_util.flatten_dict(data)
    return traverse_util.unflatten_dict(
        {k: fn(k, v) for k, v in flat.items()})
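# `mask` closes over a free variable `fn` from its enclosing scope. A hedged
# guess at a factory wrapping it (the name `make_mask_fn` is hypothetical,
# not from the source):
from flax import traverse_util

def make_mask_fn(fn):
    def mask(data):
        flat = traverse_util.flatten_dict(data)
        return traverse_util.unflatten_dict(
            {k: fn(k, v) for k, v in flat.items()})
    return mask

no_decay_on_bias = make_mask_fn(lambda path, value: path[-1] != "bias")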