def test_to_fp32(self):
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        model = model_class(config)

        # cast all params to fp16 and back to fp32
        params = model.to_fp16(model.params)
        params = model.to_fp32(params)

        # test if all params are in fp32
        types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
        for name, type_ in types.items():
            self.assertEqual(type_, jnp.float32, msg=f"param {name} is not in fp32.")

        # test masking
        flat_params = flatten_dict(params)
        key = random.choice(list(flat_params.keys()))  # choose a random param
        mask = {path: path != key for path in flat_params}  # don't cast the key
        mask = unflatten_dict(mask)

        # cast to fp16 and back to fp32 with mask
        params = model.to_fp16(model.params)
        params = model.to_fp32(params, mask)

        # test if all params are in fp32 except key
        types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
        for name, type_ in types.items():
            if name == key:
                self.assertEqual(type_, jnp.float16, msg=f"param {name} should be in fp16.")
            else:
                self.assertEqual(type_, jnp.float32, msg=f"param {name} is not in fp32.")
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) batch_size, sequence_length = input_ids.shape position_ids = jnp.broadcast_to( jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} module_init_outputs = self.module.init( rngs, input_ids, attention_mask, position_ids, return_dict=False, ) random_params = module_init_outputs["params"] if params is not None: random_params = flatten_dict(unfreeze(random_params)) params = flatten_dict(unfreeze(params)) for missing_key in self._missing_keys: params[missing_key] = random_params[missing_key] self._missing_keys = set() return freeze(unflatten_dict(params)) else: return random_params
def test_push_to_hub_in_organization(self):
    config = BertConfig(
        vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
    )
    model = FlaxBertModel(config)
    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(
            os.path.join(tmp_dir, "test-model-flax-org"),
            push_to_hub=True,
            use_auth_token=self._token,
            organization="valid_org",
        )

        new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org")

        base_params = flatten_dict(unfreeze(model.params))
        new_params = flatten_dict(unfreeze(new_model.params))

        for key in base_params.keys():
            max_diff = (base_params[key] - new_params[key]).sum().item()
            self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") token_type_ids = jnp.zeros_like(input_ids) attention_mask = jnp.ones_like(input_ids) head_mask = jnp.ones( (self.config.num_hidden_layers, self.config.num_attention_heads)) params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} random_params = self.module.init(rngs, input_ids, attention_mask, token_type_ids, head_mask, return_dict=False)["params"] if params is not None: random_params = flatten_dict(unfreeze(random_params)) params = flatten_dict(unfreeze(params)) for missing_key in self._missing_keys: params[missing_key] = random_params[missing_key] self._missing_keys = set() return freeze(unflatten_dict(params)) else: return random_params
def test_save_load_to_base_pt(self):
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    base_class = FLAX_MODEL_MAPPING[config.__class__]

    for model_class in self.all_model_classes:
        if model_class == base_class:
            continue

        model = model_class(config)
        base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))

        # convert Flax model to PyTorch model
        pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
        pt_model = pt_model_class(config).eval()
        pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

        # check that all base model weights are loaded correctly
        with tempfile.TemporaryDirectory() as tmpdirname:
            pt_model.save_pretrained(tmpdirname)
            base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

            base_params = flatten_dict(unfreeze(base_model.params))

            for key in base_params_from_head.keys():
                max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def test_save_load_to_base(self):
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    base_class = FLAX_MODEL_MAPPING[config.__class__]

    for model_class in self.all_model_classes:
        if model_class == base_class:
            continue

        model = model_class(config)
        base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))

        # check that all base model weights are loaded correctly
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            base_model = base_class.from_pretrained(tmpdirname)

            base_params = flatten_dict(unfreeze(base_model.params))

            for key in base_params_from_head.keys():
                max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
    # init input tensors
    pixel_values = jnp.zeros(input_shape, dtype=self.dtype)

    params_rng, dropout_rng = jax.random.split(rng)
    dropout_rng, droppath_rng = jax.random.split(dropout_rng)
    rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng}

    random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"]

    if params is not None:
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))
    else:
        return random_params
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") attention_mask = jnp.ones_like(input_ids) position_ids = jnp.broadcast_to( jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"] if params is not None: random_params = flatten_dict(unfreeze(random_params)) params = flatten_dict(unfreeze(params)) for missing_key in self._missing_keys: params[missing_key] = random_params[missing_key] self._missing_keys = set() return freeze(unflatten_dict(params)) else: return random_params
def load_pretrained(variables, url='', default_cfg=None, filter_fn=None):
    if not url:
        assert default_cfg is not None and default_cfg['url']
        url = default_cfg['url']
    state_dict = load_state_dict_from_url(url, transpose=True)

    source_params, source_state = split_state_dict(state_dict)
    if filter_fn is not None:
        # filter after split as we may have modified the split criteria (ie bn running vars)
        source_params = filter_fn(source_params)
        source_state = filter_fn(source_state)

    # FIXME better way to do this?
    var_unfrozen = unfreeze(variables)
    missing_keys = []
    flat_params = flatten_dict(var_unfrozen['params'])
    flat_param_keys = set()
    for k, v in flat_params.items():
        flat_k = '.'.join(k)
        if flat_k in source_params:
            # check the source tensor against the target shape (the original
            # compared flat_params[k] with v, which are the same object)
            assert source_params[flat_k].shape == v.shape
            flat_params[k] = source_params[flat_k]
        else:
            missing_keys.append(flat_k)
        flat_param_keys.add(flat_k)
    unexpected_keys = list(set(source_params.keys()).difference(flat_param_keys))
    params = freeze(unflatten_dict(flat_params))

    flat_state = flatten_dict(var_unfrozen['batch_stats'])
    flat_state_keys = set()
    for k, v in flat_state.items():
        flat_k = '.'.join(k)
        if flat_k in source_state:
            assert source_state[flat_k].shape == v.shape
            flat_state[k] = source_state[flat_k]
        else:
            missing_keys.append(flat_k)
        flat_state_keys.add(flat_k)
    unexpected_keys.extend(list(set(source_state.keys()).difference(flat_state_keys)))
    batch_stats = freeze(unflatten_dict(flat_state))

    if missing_keys:
        print(f' WARNING: {len(missing_keys)} keys missing while loading state_dict. {str(missing_keys)}')
    if unexpected_keys:
        print(f' WARNING: {len(unexpected_keys)} unexpected keys found while loading state_dict. {str(unexpected_keys)}')

    return dict(params=params, batch_stats=batch_stats)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
    encoder_input_shape, decoder_input_shape = input_shape

    # init input tensors
    input_ids = jnp.zeros(encoder_input_shape, dtype="i4")
    attention_mask = jnp.ones_like(input_ids)
    decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
    decoder_attention_mask = jnp.ones_like(decoder_input_ids)

    batch_size, sequence_length = input_ids.shape
    position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

    decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
    if decoder_batch_size != batch_size:
        raise ValueError(
            f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} "
            f"for encoder and {decoder_batch_size} for decoder."
        )
    decoder_position_ids = jnp.broadcast_to(
        jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length)
    )

    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}

    random_params = self.module.init(
        rngs,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
    )["params"]

    if params is not None:
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))
    else:
        return random_params
def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    layer_norm_params = [
        (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
    ]
    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)
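# Usage sketch (not from the source): masks produced by `decay_mask_fn` above are
# typically handed to `optax.adamw`, whose `mask` argument accepts either a boolean
# pytree or a callable that builds one from the params. The hyperparameter values
# here are illustrative assumptions.
import optax

optimizer = optax.adamw(
    learning_rate=1e-3,
    weight_decay=0.01,
    mask=decay_mask_fn,  # weight decay is applied only where the mask is True
)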
def set_partitions(num_partitions, in_dict):
    rules = _get_partition_rules(num_partitions)
    replace = _replacement_rules(rules)
    initd = {k: _unmatched for k in flatten_dict(in_dict)}
    result = {k: replace(k, v) for k, v in initd.items()}
    assert _unmatched not in result.values(), 'Incomplete partition spec.'
    return unflatten_dict(result)
def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
    """
    Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
    """

    # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
    def conditional_cast(param):
        if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
            param = param.astype(dtype)
        return param

    if mask is None:
        return jax.tree_map(conditional_cast, params)

    flat_params = flatten_dict(params)
    flat_mask, _ = jax.tree_flatten(mask)

    for masked, key in zip(flat_mask, flat_params.keys()):
        if masked:
            param = flat_params[key]
            flat_params[key] = conditional_cast(param)

    return unflatten_dict(flat_params)
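# Usage sketch (hypothetical `model` instance): build a mask with
# flatten_dict/unflatten_dict so that `to_fp16` (which delegates to
# `_cast_floating_to` above) skips selected leaves. Keeping biases in fp32
# is just an illustrative choice.
flat_params = flatten_dict(model.params)
mask = {path: path[-1] != "bias" for path in flat_params}  # True -> cast, False -> keep
mask = unflatten_dict(mask)
half_params = model.to_fp16(model.params, mask)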
def update_GCNN_parity(params):
    """Adds biases of parity-flip layers to the corresponding no-flip layers.

    Corrects for changes in GCNN_parity due to PR #1030 in NetKet 3.3.

    Args:
        params: a parameter pytree
    """
    # unfreeze just in case, doesn't break with a plain dict
    params = flatten_dict(unfreeze(params))
    to_remove = []
    for path in params:
        if (
            len(path) > 1
            and path[-2].startswith("equivariant_layers_flip")
            and path[-1] == "bias"
        ):
            alt_path = (
                *path[:-2],
                path[-2].replace("equivariant_layers_flip", "equivariant_layers"),
                path[-1],
            )
            params[alt_path] = params[alt_path] + params[path]
            to_remove.append(path)
    for path in to_remove:
        del params[path]
    return unflatten_dict(params)
def decay_mask_fn(params): flat_params = traverse_util.flatten_dict(params) flat_mask = { path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params } return traverse_util.unflatten_dict(flat_mask)
def __init__(
    self,
    config: PretrainedConfig,
    module: nn.Module,
    input_shape: Tuple = (1, 1),
    seed: int = 0,
    dtype: jnp.dtype = jnp.float32,
):
    if config is None:
        raise ValueError("config cannot be None")

    if module is None:
        raise ValueError("module cannot be None")

    # Those are private to be exposed as typed property on derived classes.
    self._config = config
    self._module = module

    # Those are public as their type is generic to every derived classes.
    self.key = PRNGKey(seed)
    self.dtype = dtype

    # randomly initialized parameters
    random_params = self.init(self.key, input_shape)

    # save required_params as set
    self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
    self.params = random_params
def test_flatten_dict(self):
    xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
    flat_xs = traverse_util.flatten_dict(xs)
    self.assertEqual(flat_xs, {
        ('foo',): 1,
        ('bar', 'a'): 2,
    })
def decay_mask_fn(params): flat_params = traverse_util.flatten_dict(params) flat_mask = { path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")]) for path in flat_params } return traverse_util.unflatten_dict(flat_mask)
def weight_decay_mask(params):
    params = traverse_util.flatten_dict(params)
    # test the flattened *path* k, not the value v (the original indexed the
    # parameter array itself, which never matches a string)
    mask = {k: (k[-1] != "bias" and k[-2:] != ("LayerNorm", "scale")) for k in params}
    return traverse_util.unflatten_dict(mask)
def get_suffix_module_pairs(module_tree) -> List[Tuple[str, Type["Module"]]]:
    """Helper for naming pytrees of submodules."""
    if isinstance(module_tree, Module):
        return [('', module_tree)]
    else:
        flat_tree = traverse_util.flatten_dict(serialization.to_state_dict(module_tree))
        return [('_' + '_'.join(k), v) for k, v in flat_tree.items()]
def _get_suffix_value_pairs(tree_or_leaf: Any) -> List[Tuple[str, Type["Module"]]]:
    """Helper for naming pytrees of submodules."""
    dict_or_leaf = serialization.to_state_dict(tree_or_leaf)
    if dict_or_leaf == {} or not isinstance(dict_or_leaf, dict):
        return [('', tree_or_leaf)]
    else:
        flat_dict = traverse_util.flatten_dict(dict_or_leaf)
        return [('_' + '_'.join(k), v) for k, v in flat_dict.items()]
def params(self, params: Union[Dict, FrozenDict]):
    if isinstance(params, FrozenDict):
        params = unfreeze(params)
    param_keys = set(flatten_dict(params).keys())
    if len(self.required_params - param_keys) > 0:
        raise ValueError(
            "Some parameters are missing. Make sure that `params` include the following "
            f"parameters {self.required_params - param_keys}"
        )
    self._params = freeze(params)
def partition_nested_dict(d, flat_left_keys):
    left, right = {}, {}
    flat_left_keys = set(flat_left_keys)
    for k, v in flatten_dict(d).items():
        if k in flat_left_keys:
            left[k] = v
        else:
            right[k] = v
    return tuple(map(unflatten_dict, (left, right)))
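# Usage sketch for `partition_nested_dict` (illustrative data): split a params
# tree into two trees by flattened key, e.g. to freeze one subtree during training.
d = {"encoder": {"kernel": 1, "bias": 2}, "head": {"kernel": 3}}
frozen, trainable = partition_nested_dict(d, flat_left_keys=[("encoder", "kernel")])
assert frozen == {"encoder": {"kernel": 1}}
assert trainable == {"encoder": {"bias": 2}, "head": {"kernel": 3}}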
def dict_replace(col, target, leaf_only=True):
    col_flat = flatten_dict(unfreeze(col))
    diff = {}
    for keys_flat in col_flat.keys():
        for tar_key, tar_val in target.items():
            if keys_flat[-1] == tar_key if leaf_only else (tar_key in keys_flat):
                diff[keys_flat] = tar_val
    col_flat.update(diff)
    col = unflatten_dict(col_flat)
    return col
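# Usage sketch for `dict_replace` (illustrative data): with `leaf_only=True` the
# match is on the final path component, so every "bias" leaf is replaced.
col = {"layer_0": {"kernel": 1, "bias": 2}, "layer_1": {"kernel": 3, "bias": 4}}
col = dict_replace(col, target={"bias": 0})
assert col == {"layer_0": {"kernel": 1, "bias": 0}, "layer_1": {"kernel": 3, "bias": 0}}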
def _map_over_modules_in_tree(fn, tree_or_leaf):
    """Helper for mapping function over submodules."""
    dict_or_leaf = serialization.to_state_dict(tree_or_leaf)
    if not isinstance(dict_or_leaf, dict) or dict_or_leaf == {}:
        return fn('', tree_or_leaf)
    else:
        flat_dict = traverse_util.flatten_dict(dict_or_leaf, keep_empty_nodes=True)
        mapped_flat_dict = {k: fn('_' + '_'.join(k), v) for k, v in _sorted_items(flat_dict)}
        return serialization.from_state_dict(
            tree_or_leaf, traverse_util.unflatten_dict(mapped_flat_dict)
        )
def test_flatten_dict_keep_empty(self):
    xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
    flat_xs = traverse_util.flatten_dict(xs, keep_empty_nodes=True)
    self.assertEqual(flat_xs, {
        ('foo',): 1,
        ('bar', 'a'): 2,
        ('bar', 'b'): traverse_util.empty_node,
    })
    xs_restore = traverse_util.unflatten_dict(flat_xs)
    self.assertEqual(xs, xs_restore)
def test_default_params_dtype(self):
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        # check if all params are still in float32 when dtype of computation is half-precision
        model = model_class(config, dtype=jnp.float16)
        types = jax.tree_map(lambda x: x.dtype, model.params)
        types = flatten_dict(types)

        for name, type_ in types.items():
            self.assertEqual(type_, jnp.float32, msg=f"param {name} is not initialized in fp32.")
def restore_state(self, state_dict):
    """Restore parameter and optimizer state from state dictionary.

    Adapted from https://github.com/google-research/t5x/blob/main/t5x/optimizers.py.
    Includes support to handle `optax.EmptyState`.

    Args:
        state_dict: Contains desired new parameters and optimizer state

    Returns:
        Updated train state.
    """
    params = serialization.from_state_dict(self.params, state_dict["params"])

    # Get all the possible keys in the reference optimizer state.
    flat_ref_opt_state_dict = traverse_util.flatten_dict(
        serialization.to_state_dict(self.opt_state), keep_empty_nodes=True, sep="/"
    )
    flat_src_opt_state_dict = dict(traverse_util.flatten_dict(state_dict["opt_state"], sep="/"))

    # Adding the empty paths back to flat_src_opt_state_dict.
    for k, v in flat_ref_opt_state_dict.items():
        if k in flat_src_opt_state_dict:
            continue

        # The key is not in the input state dict, presumably because it
        # corresponds to an empty dict.
        if v != traverse_util.empty_node:
            raise ValueError(
                f"Failed to restore optimizer state, path {k} is not present "
                "in the input optimizer state dict."
            )
        flat_src_opt_state_dict[k] = v

    # Restore state from the enhanced state dict.
    opt_state = serialization.from_state_dict(
        self.opt_state, traverse_util.unflatten_dict(flat_src_opt_state_dict, sep="/")
    )
    return self.replace(params=params, opt_state=opt_state)
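# Round-trip sketch for the `sep` argument used above: with `sep="/"`,
# flatten_dict joins each path tuple into a single string key, and
# unflatten_dict splits it back.
from flax import traverse_util

d = {"a": {"b": 1, "c": {"d": 2}}}
flat = traverse_util.flatten_dict(d, sep="/")  # {"a/b": 1, "a/c/d": 2}
assert traverse_util.unflatten_dict(flat, sep="/") == d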
def test_flatten_dict_is_leaf(self):
    xs = {'foo': {'c': 4}, 'bar': {'a': 2, 'b': {}}}
    flat_xs = traverse_util.flatten_dict(
        xs, is_leaf=lambda k, x: len(k) == 1 and len(x) == 2
    )
    self.assertEqual(flat_xs, {
        ('foo', 'c'): 4,
        ('bar',): {'a': 2, 'b': {}},
    })
    xs_restore = traverse_util.unflatten_dict(flat_xs)
    self.assertEqual(xs, xs_restore)
def __init__(
    self,
    config: PretrainedConfig,
    module: nn.Module,
    input_shape: Tuple = (1, 1),
    seed: int = 0,
    dtype: jnp.dtype = jnp.float32,
    _do_init: bool = True,
):
    if config is None:
        raise ValueError("config cannot be None")

    if module is None:
        raise ValueError("module cannot be None")

    # Those are private to be exposed as typed property on derived classes.
    self._config = config
    self._module = module

    # Those are public as their type is generic to every derived classes.
    self.key = PRNGKey(seed)
    self.dtype = dtype
    self.input_shape = input_shape

    # To check if the model was initialized automatically.
    self._is_initialized = _do_init

    if _do_init:
        # randomly initialized parameters
        random_params = self.init_weights(self.key, input_shape)
        params_shape_tree = jax.eval_shape(lambda params: params, random_params)
    else:
        init_fn = partial(self.init_weights, input_shape=input_shape)
        params_shape_tree = jax.eval_shape(init_fn, self.key)

        logger.info(
            "Model weights are not initialized as `_do_init` is set to `False`. "
            f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights."
        )

    # get the shape of the parameters
    self._params_shape_tree = params_shape_tree

    # save required_params as set
    self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())

    # initialize the parameters
    if _do_init:
        self.params = random_params
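# Usage sketch for the `_do_init=False` path above (model, config, and inputs are
# illustrative): construction records only the parameter shape tree, so weights
# are materialized by calling `init_weights` explicitly and then passed at call
# time, since assigning `model.params` on an uninitialized model is rejected.
model = FlaxBertModel(config, _do_init=False)
params = model.init_weights(model.key, model.input_shape)
outputs = model(input_ids, params=params)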