def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
    # init input tensors
    input_ids = jnp.zeros(input_shape, dtype="i4")
    attention_mask = jnp.ones_like(input_ids)
    position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}

    random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]

    if params is not None:
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))
    else:
        return random_params
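# Hedged usage sketch (not from the source): driving `init_weights` with a
# concrete model. FlaxBertModel follows the same pattern as the method above;
# the config values and input shape are arbitrary.
import jax
from transformers import BertConfig, FlaxBertModel

def demo_init_weights():
    config = BertConfig(num_hidden_layers=2, hidden_size=32, num_attention_heads=4, intermediate_size=64)
    model = FlaxBertModel(config)
    # A fresh init returns the full parameter tree for the given input shape;
    # passing a partial `params` tree would only fill in the missing keys.
    fresh = model.init_weights(jax.random.PRNGKey(0), input_shape=(1, 16))
    return fresh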
def test_save_load_to_base(self):
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    base_class = FLAX_MODEL_MAPPING[config.__class__]

    for model_class in self.all_model_classes:
        if model_class == base_class:
            continue

        model = model_class(config)
        base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))

        # check that all base model weights are loaded correctly
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            base_model = base_class.from_pretrained(tmpdirname)

            base_params = flatten_dict(unfreeze(base_model.params))

            for key in base_params_from_head.keys():
                max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
    # init input tensors
    pixel_values = jnp.zeros(input_shape, dtype=self.dtype)

    params_rng, dropout_rng = jax.random.split(rng)
    dropout_rng, droppath_rng = jax.random.split(dropout_rng)
    rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng}

    random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"]

    if params is not None:
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))
    else:
        return random_params
def test_save_load_to_base_pt(self):
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    base_class = FLAX_MODEL_MAPPING[config.__class__]

    for model_class in self.all_model_classes:
        if model_class == base_class:
            continue

        model = model_class(config)
        base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))

        # convert Flax model to PyTorch model
        pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
        pt_model = pt_model_class(config).eval()
        pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

        # check that all base model weights are loaded correctly
        with tempfile.TemporaryDirectory() as tmpdirname:
            pt_model.save_pretrained(tmpdirname)
            base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

            base_params = flatten_dict(unfreeze(base_model.params))

            for key in base_params_from_head.keys():
                max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
    # init input tensors
    input_ids = jnp.zeros(input_shape, dtype="i4")
    attention_mask = jnp.ones_like(input_ids)
    batch_size, sequence_length = input_ids.shape
    position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}

    module_init_outputs = self.module.init(
        rngs,
        input_ids,
        attention_mask,
        position_ids,
        return_dict=False,
    )
    random_params = module_init_outputs["params"]

    if params is not None:
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))
    else:
        return random_params
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
    # init input tensors
    input_ids = jnp.zeros(input_shape, dtype="i4")
    token_type_ids = jnp.zeros_like(input_ids)
    attention_mask = jnp.ones_like(input_ids)
    head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))

    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}

    random_params = self.module.init(
        rngs, input_ids, attention_mask, token_type_ids, head_mask, return_dict=False
    )["params"]

    if params is not None:
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))
    else:
        return random_params
def test_push_to_hub_in_organization(self):
    config = BertConfig(
        vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
    )
    model = FlaxBertModel(config)

    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(
            os.path.join(tmp_dir, "test-model-flax-org"),
            push_to_hub=True,
            use_auth_token=self._token,
            organization="valid_org",
        )

        new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org")

        base_params = flatten_dict(unfreeze(model.params))
        new_params = flatten_dict(unfreeze(new_model.params))

        for key in base_params.keys():
            max_diff = (base_params[key] - new_params[key]).sum().item()
            self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
    encoder_input_shape, decoder_input_shape = input_shape

    # init input tensors
    input_ids = jnp.zeros(encoder_input_shape, dtype="i4")
    attention_mask = jnp.ones_like(input_ids)
    decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
    decoder_attention_mask = jnp.ones_like(decoder_input_ids)

    batch_size, sequence_length = input_ids.shape
    position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

    decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
    if not decoder_batch_size == batch_size:
        raise ValueError(
            f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder "
            f"and {decoder_batch_size} for decoder."
        )
    decoder_position_ids = jnp.broadcast_to(
        jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length)
    )

    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}

    random_params = self.module.init(
        rngs,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
    )["params"]

    if params is not None:
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))
    else:
        return random_params
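# Hedged note (sketch, `model` is assumed to be an instance of the
# encoder-decoder class above): here `input_shape` is a pair of shapes,
# one per stack, matching the unpacking at the top of `init_weights`.
input_shape = ((1, 16), (1, 16))  # (encoder_input_shape, decoder_input_shape)
params = model.init_weights(jax.random.PRNGKey(0), input_shape)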
def __init__(
    self,
    config: PretrainedConfig,
    module: nn.Module,
    input_shape: Tuple = (1, 1),
    seed: int = 0,
    dtype: jnp.dtype = jnp.float32,
):
    if config is None:
        raise ValueError("config cannot be None")

    if module is None:
        raise ValueError("module cannot be None")

    # Those are private to be exposed as typed property on derived classes.
    self._config = config
    self._module = module

    # Those are public as their type is generic to every derived class.
    self.key = PRNGKey(seed)
    self.dtype = dtype

    # randomly initialized parameters
    random_params = self.init(self.key, input_shape)

    # save required_params as set
    self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
    self.params = random_params
def params(self, params: Union[Dict, FrozenDict]):
    if isinstance(params, FrozenDict):
        params = unfreeze(params)
    param_keys = set(flatten_dict(params).keys())
    if len(self.required_params - param_keys) > 0:
        raise ValueError(
            "Some parameters are missing. Make sure that `params` includes the following "
            f"parameters {self.required_params - param_keys}"
        )
    self._params = freeze(params)
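# Hedged sketch of the setter contract above (not from the source): a complete
# plain dict is accepted, validated against `required_params`, and re-frozen;
# an incomplete tree raises. FlaxBertModel is used as a concrete stand-in and
# the config values are arbitrary.
from transformers import BertConfig, FlaxBertModel
from flax.core.frozen_dict import unfreeze

model = FlaxBertModel(BertConfig(num_hidden_layers=2, hidden_size=32, num_attention_heads=4, intermediate_size=64))
plain = unfreeze(model.params)  # plain dict is accepted...
model.params = plain            # ...validated and frozen again by the setter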
def __init__(
    self,
    config: PretrainedConfig,
    module: nn.Module,
    input_shape: Tuple = (1, 1),
    seed: int = 0,
    dtype: jnp.dtype = jnp.float32,
    _do_init: bool = True,
):
    if config is None:
        raise ValueError("config cannot be None")

    if module is None:
        raise ValueError("module cannot be None")

    # Those are private to be exposed as typed property on derived classes.
    self._config = config
    self._module = module

    # Those are public as their type is generic to every derived class.
    self.key = PRNGKey(seed)
    self.dtype = dtype
    self.input_shape = input_shape

    # To check if the model was initialized automatically.
    self._is_initialized = _do_init

    if _do_init:
        # randomly initialized parameters
        random_params = self.init_weights(self.key, input_shape)
        params_shape_tree = jax.eval_shape(lambda params: params, random_params)
    else:
        init_fn = partial(self.init_weights, input_shape=input_shape)
        params_shape_tree = jax.eval_shape(init_fn, self.key)

        logger.info(
            "Model weights are not initialized as `_do_init` is set to `False`. "
            f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights."
        )

    # get the shape of the parameters
    self._params_shape_tree = params_shape_tree

    # save required_params as set
    self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())

    # initialize the parameters
    if _do_init:
        self.params = random_params
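# Hedged sketch (assumption, not from the source): with `_do_init=False` only
# parameter *shapes* are traced via `jax.eval_shape`; no weights are
# materialized until the caller initializes them explicitly.
from transformers import BertConfig, FlaxBertModel

config = BertConfig(num_hidden_layers=2, hidden_size=32, num_attention_heads=4, intermediate_size=64)
model = FlaxBertModel(config, _do_init=False)               # shapes only, no weights
params = model.init_weights(model.key, model.input_shape)   # explicit, on demand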
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
    # convert pytorch tensor to numpy
    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

    random_flax_state_dict = flatten_dict(unfreeze(flax_model.params))
    flax_state_dict = {}

    remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
        flax_model.base_model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
    )
    add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params) and (
        flax_model.base_model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
    )

    # Need to change some parameters name to match Flax names so that we don't have to fork any layer
    for pt_key, pt_tensor in pt_state_dict.items():
        pt_tuple_key = tuple(pt_key.split("."))

        has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
        require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict

        if remove_base_model_prefix and has_base_model_prefix:
            pt_tuple_key = pt_tuple_key[1:]
        elif add_base_model_prefix and require_base_model_prefix:
            pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key

        if pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
            pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
            pt_tensor = pt_tensor.T
        elif pt_tuple_key[-1] == "gamma":
            pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
        elif pt_tuple_key[-1] == "beta":
            pt_tuple_key = pt_tuple_key[:-1] + ("bias",)

        if pt_tuple_key in random_flax_state_dict:
            if random_flax_state_dict[pt_tuple_key].shape != pt_tensor.shape:
                raise ValueError(
                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                    f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
                )

        # add unexpected weight so that warning is thrown
        flax_state_dict[pt_tuple_key] = pt_tensor

    return unflatten_dict(flax_state_dict)
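# Hedged usage sketch: converting a saved PyTorch state dict into Flax
# parameters. The checkpoint path and `flax_model` are placeholders, not from
# the source.
import torch

pt_state_dict = torch.load("pytorch_model.bin", map_location="cpu")  # hypothetical path
flax_params = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)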
def params(self, params: Union[Dict, FrozenDict]):
    # don't set params if the model is not initialized
    if not self._is_initialized:
        raise ValueError(
            "`params` cannot be set from model when the model is created with `_do_init=False`. "
            "Store the params outside of the model."
        )

    if isinstance(params, FrozenDict):
        params = unfreeze(params)
    param_keys = set(flatten_dict(params).keys())
    if len(self.required_params - param_keys) > 0:
        raise ValueError(
            "Some parameters are missing. Make sure that `params` includes the following "
            f"parameters {self.required_params - param_keys}"
        )
    self._params = params
def init_cache(self, batch_size, max_length):
    r"""
    Args:
        batch_size (`int`):
            batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
        max_length (`int`):
            maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
            cache.
    """
    # init input variables to retrieve cache
    input_ids = jnp.ones((batch_size, max_length), dtype="i4")
    attention_mask = jnp.ones_like(input_ids, dtype="i4")
    position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

    init_variables = self.module.init(
        jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
    )
    return unfreeze(init_variables["cache"])
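# Hedged sketch of the intended decoding loop (model and inputs assumed; a
# decoder-only transformers model such as FlaxGPT2LMHeadModel exposes the same
# `init_cache`): the cache is passed as `past_key_values` and the updated
# cache returned by each forward pass feeds the next step.
import jax.numpy as jnp
from transformers import FlaxGPT2LMHeadModel

model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
input_ids = jnp.array([[464, 3290]], dtype="i4")
past_key_values = model.init_cache(batch_size=1, max_length=16)
outputs = model(input_ids, past_key_values=past_key_values)
past_key_values = outputs.past_key_values  # carry the updated cache forward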
def init_cache(self, batch_size, max_length, encoder_outputs):
    r"""
    Args:
        batch_size (`int`):
            batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
        max_length (`int`):
            maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
            cache.
        encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray))]`):
            `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
            `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a
            sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            of the decoder.
    """
    # init input variables to retrieve cache
    decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
    decoder_attention_mask = jnp.ones_like(decoder_input_ids)
    decoder_position_ids = jnp.broadcast_to(
        jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
    )

    def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
        decoder_module = module._get_decoder_module()
        return decoder_module(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            **kwargs,
        )

    init_variables = self.module.init(
        jax.random.PRNGKey(0),
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        decoder_position_ids=decoder_position_ids,
        encoder_hidden_states=encoder_outputs[0],
        init_cache=True,
        method=_decoder_forward,  # we only need to call the decoder to init the cache
    )
    return unfreeze(init_variables["cache"])
def __call__(
    self,
    input_ids: jnp.ndarray,
    attention_mask: Optional[jnp.ndarray] = None,
    position_ids: Optional[jnp.ndarray] = None,
    encoder_hidden_states: Optional[jnp.ndarray] = None,
    encoder_attention_mask: Optional[jnp.ndarray] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    train: bool = False,
    params: dict = None,
    past_key_values: dict = None,
    dropout_rng: PRNGKey = None,
):
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.return_dict

    if encoder_hidden_states is not None and encoder_attention_mask is None:
        batch_size, sequence_length = encoder_hidden_states.shape[:2]
        encoder_attention_mask = jnp.ones((batch_size, sequence_length))

    # prepare encoder inputs
    if attention_mask is None:
        attention_mask = jnp.ones_like(input_ids)
    if position_ids is None:
        batch_size, sequence_length = input_ids.shape
        position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

    # Handle any PRNG if needed
    rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

    inputs = {"params": params or self.params}

    # If past_key_values are passed then the cache is already initialized; a private flag `init_cache` has to be
    # passed down to ensure the cache is used. It also has to be made sure that the cache is marked as mutable
    # so that it can be changed by the FlaxXGLMAttention module.
    if past_key_values:
        inputs["cache"] = past_key_values
        mutable = ["cache"]
    else:
        mutable = False

    outputs = self.module.apply(
        inputs,
        input_ids=jnp.array(input_ids, dtype="i4"),
        attention_mask=jnp.array(attention_mask, dtype="i4"),
        position_ids=jnp.array(position_ids, dtype="i4"),
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        deterministic=not train,
        rngs=rngs,
        mutable=mutable,
    )

    # add updated cache to model output
    if past_key_values is not None and return_dict:
        outputs, past_key_values = outputs
        outputs["past_key_values"] = unfreeze(past_key_values["cache"])
        return outputs
    elif past_key_values is not None and not return_dict:
        outputs, past_key_values = outputs
        outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]

    return outputs
def main_opt(N, l, i0, nn_arq, act_fun, n_epochs, lr, w_decay, rho_g):
    start_time = time.time()

    str_nn_arq = ''
    for item in nn_arq:
        str_nn_arq = str_nn_arq + '_{}'.format(item)

    f_job = 'nn_arq{}_N_{}_i0_{}_l_{}_batch'.format(str_nn_arq, N, i0, l)
    f_out = '{}/out_opt_{}.txt'.format(r_dir, f_job)
    f_w_nn = '{}/W_{}.npy'.format(r_dir, f_job)
    file_results = '{}/data_nh3_{}.npy'.format(r_dir, f_job)

    # --------------------------------------
    # Data
    n_atoms = 4
    batch_size = 768  # 1024#768#512#256#128#64#32

    Dtr, Dval, Dt = load_data(file_results, N, l)
    Xtr, gXtr, gXctr, ytr = Dtr
    Xval, gXval, gXcval, yval = Dval
    Xt, gXt, gXct, yt = Dt
    print(gXtr.shape, gXtr.shape, gXctr.shape, ytr.shape)

    # --------------------------------
    # BATCHES
    n_complete_batches, leftover = divmod(N, batch_size)
    n_batches = n_complete_batches + bool(leftover)

    def data_stream():
        rng = onpr.RandomState(0)
        while True:
            perm = rng.permutation(N)
            for i in range(n_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield Xtr[batch_idx], gXtr[batch_idx], gXctr[batch_idx], ytr[batch_idx]

    batches = data_stream()

    # --------------------------------
    f = open(f_out, 'a+')
    print('-----------------------------------', file=f)
    print('Starting time', file=f)
    print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M"), file=f)
    print('-----------------------------------', file=f)
    print(f_out, file=f)
    print('N = {}, n_atoms = {}, data_random = {}, NN_random = {}'.format(N, n_atoms, l, i0), file=f)
    print(nn_arq, file=f)
    print('lr = {}, w decay = {}'.format(lr, w_decay), file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('N Epoch = {}'.format(n_epochs), file=f)
    print('rho G = {}'.format(rho_g), file=f)
    print('-----------------------------------', file=f)
    f.close()

    # --------------------------------------
    # initialize NN
    nn_arq.append(3)
    tuple_nn_arq = tuple(nn_arq)
    nn_model = NN_adiab(n_atoms, tuple_nn_arq)

    def get_init_NN_params(key):
        x = Xtr[0, :]
        x = x[None, :]
        # x = jnp.ones((1, Xtr.shape[1]))
        variables = nn_model.init(key, x)
        return variables

    # Initialize parameters
    rng = random.PRNGKey(i0)
    rng, subkey = jax.random.split(rng)
    params = get_init_NN_params(subkey)

    f = open(f_out, 'a+')
    if os.path.isfile(f_w_nn):
        print('Reading NN parameters from prev calculation!', file=f)
        print('-----------------------', file=f)
        nn_dic = jnp.load(f_w_nn, allow_pickle=True)
        params = unfreeze(params)
        params['params'] = nn_dic.item()['params']
        params = freeze(params)
        # print(params)
    f.close()

    init_params = params

    # --------------------------------------
    # Phys functions
    @jit
    def nn_adiab(params, x):
        y_ad_pred = nn_model.apply(params, x)
        return y_ad_pred

    @jit
    def jac_nn_adiab(params, x):
        g_y_pred = jacrev(nn_adiab, argnums=1)(params, x[None, :])
        return jnp.reshape(g_y_pred, (2, g_y_pred.shape[-1]))

    # --------------------------------------
    # training loss functions
    @jit
    def f_loss_ad_energy(params, batch):
        X_inputs, _, _, y_true = batch
        y_pred = nn_adiab(params, X_inputs)
        diff_y = y_pred - y_true  # Ha2cm*
        return jnp.linalg.norm(diff_y, axis=0)

    @jit
    def f_loss_jac(params, batch):
        X_inputs, gX_inputs, _, y_true = batch
        gX_pred = vmap(jac_nn_adiab, (None, 0))(params, X_inputs)
        diff_g_X = gX_pred - gX_inputs
        # jnp.linalg.norm(diff_g_X, axis=0)
        diff_g_X0 = diff_g_X[:, 0, :]
        diff_g_X1 = diff_g_X[:, 1, :]
        l0 = jnp.linalg.norm(diff_g_X0)
        l1 = jnp.linalg.norm(diff_g_X1)
        return jnp.stack([l0, l1])

    # ------
    @jit
    def f_loss(params, rho_g, batch):
        rho_g = jnp.exp(rho_g)
        loss_ad_energy = f_loss_ad_energy(params, batch)
        loss_jac_energy = f_loss_jac(params, batch)
        loss = jnp.vdot(jnp.ones_like(loss_ad_energy), loss_ad_energy) + jnp.vdot(rho_g, loss_jac_energy)
        return loss

    # --------------------------------------
    # Optimization and Training
    # Perform a single training step.
    @jit
    def train_step(optimizer, rho_g, batch):  # , learning_rate_fn, model
        grad_fn = jax.value_and_grad(f_loss)
        loss, grad = grad_fn(optimizer.target, rho_g, batch)
        optimizer = optimizer.apply_gradient(grad)  # , {"learning_rate": lr}
        return optimizer, (loss, grad)

    # @jit
    def train(rho_g, nn_params):
        optimizer = optim.Adam(learning_rate=lr, weight_decay=w_decay).create(nn_params)
        optimizer = jax.device_put(optimizer)

        train_loss = []
        loss0 = 1E16
        loss0_tot = 1E16
        itercount = itertools.count()
        f_params = init_params
        for epoch in range(n_epochs):
            for _ in range(n_batches):
                optimizer, loss_and_grad = train_step(optimizer, rho_g, next(batches))
                loss, grad = loss_and_grad
                # f = open(f_out, 'a+')
                # print(i, loss, file=f)
                # f.close()
                train_loss.append(loss)
            # params = optimizer.target
            # loss_tot = f_validation(params)

        nn_params = optimizer.target
        return nn_params, loss_and_grad, train_loss

    @jit
    def val_step(optimizer, nn_params):  # , learning_rate_fn, model
        rho_g_prev = optimizer.target
        nn_params, loss_and_grad_train, train_loss_iter = train(rho_g_prev, nn_params)
        loss_train, grad_loss_train = loss_and_grad_train

        grad_fn_val = jax.value_and_grad(f_loss, argnums=1)
        loss_val, grad_val = grad_fn_val(nn_params, optimizer.target, Dval)
        optimizer = optimizer.apply_gradient(grad_val)  # , {"learning_rate": lr}
        return optimizer, nn_params, (loss_val, loss_train, train_loss_iter), (grad_loss_train, grad_val)

    # Initialize rho_G
    rng = random.PRNGKey(0)
    rng, subkey = jax.random.split(rng)
    rho_G0 = random.uniform(subkey, shape=(2,), minval=5E-4, maxval=0.025)
    rho_G0 = jnp.log(rho_G0)
    print('Initial lambdas', rho_G0)
    init_G = rho_G0

    optimizer_out = optim.Adam(learning_rate=2E-4, weight_decay=0.).create(init_G)
    optimizer_out = jax.device_put(optimizer_out)

    f_params = init_params
    for i in range(50000):
        start_va_time = time.time()
        optimizer_out, f_params, loss_all, grad_all = val_step(optimizer_out, f_params)
        rho_g = optimizer_out.target
        loss_val, loss_train, train_loss_iter = loss_all
        grad_loss_train, grad_val = grad_all

        loss0_tot = f_loss(f_params, rho_g, Dt)

        dict_output = serialization.to_state_dict(f_params)
        jnp.save(f_w_nn, dict_output)  # unfreeze()

        f = open(f_out, 'a+')
        # print(i, rho_g, loss0, loss0_tot, (time.time() - start_va_time), file=f)
        print(i, loss_val, loss_train, (time.time() - start_va_time), file=f)
        print(jnp.exp(rho_g), file=f)
        print(grad_val, file=f)
        # print(train_loss_iter, file=f)
        # print(grad_val, file=f)
        # print(grad_loss_train, file=f)
        f.close()

    # --------------------------------------
    # Prediction
    f = open(f_out, 'a+')
    print('Prediction of the entire data set', file=f)
    print('N = {}, n_atoms = {}, random = {}'.format(N, n_atoms, i0), file=f)
    print('NN : {}'.format(nn_arq), file=f)
    print('lr = {}, w decay = {}, rho G = {}'.format(lr, w_decay, rho_g), file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('Total points = {}'.format(yt.shape[0]), file=f)

    y_pred = nn_adiab(f_params, Xt)
    gX_pred = vmap(jac_nn_adiab, (None, 0))(f_params, Xt)

    diff_y = y_pred - yt
    rmse_Ha = jnp.linalg.norm(diff_y)
    rmse_cm = jnp.linalg.norm(Ha2cm * diff_y)
    mae_Ha = jnp.linalg.norm(diff_y, ord=1)
    mae_cm = jnp.linalg.norm(Ha2cm * diff_y, ord=1)

    print('RMSE = {} [Ha]'.format(rmse_Ha), file=f)
    # `loss0` was local to train(); report the last training loss instead
    print('RMSE(tr) = {} [cm-1]'.format(loss_train), file=f)
    print('RMSE = {} [cm-1]'.format(rmse_cm), file=f)
    print('MAE = {} [Ha]'.format(mae_Ha), file=f)
    print('MAE = {} [cm-1]'.format(mae_cm), file=f)

    Dpred = jnp.column_stack((Xt, y_pred))
    data_dic = {
        'Dtr': Dtr,
        'Dpred': Dpred,
        'gXpred': gX_pred,
        'loss_tr': loss_train,
        'error_full': rmse_cm,
        'N': N,
        'l': l,
        'i0': i0,
        'rho_g': rho_g,
    }
    jnp.save(file_results, data_dic)

    print('---------------------------------', file=f)
    print('Total time = %.6f seconds ---' % ((time.time() - start_time)), file=f)
    print('---------------------------------', file=f)
    f.close()
def decode(
    self,
    decoder_input_ids,
    encoder_outputs,
    encoder_attention_mask: Optional[jnp.ndarray] = None,
    decoder_attention_mask: Optional[jnp.ndarray] = None,
    decoder_position_ids: Optional[jnp.ndarray] = None,
    past_key_values: dict = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    train: bool = False,
    params: dict = None,
    dropout_rng: PRNGKey = None,
):
    r"""
    Returns:

    Example::

        >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer
        >>> import jax.numpy as jnp

        >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
        >>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-cased', 'gpt2')

        >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

        >>> text = "My friends are cool but they eat too many carbs."
        >>> input_ids = tokenizer.encode(text, max_length=1024, return_tensors='np')
        >>> encoder_outputs = model.encode(input_ids)

        >>> decoder_start_token_id = model.config.decoder.bos_token_id
        >>> decoder_input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

        >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
        >>> logits = outputs.logits
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.return_dict

    encoder_hidden_states = encoder_outputs[0]
    if encoder_attention_mask is None:
        batch_size, sequence_length = encoder_hidden_states.shape[:2]
        encoder_attention_mask = jnp.ones((batch_size, sequence_length))

    batch_size, sequence_length = decoder_input_ids.shape
    if decoder_attention_mask is None:
        decoder_attention_mask = jnp.ones((batch_size, sequence_length))

    if decoder_position_ids is None:
        if past_key_values is not None:
            raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")

        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
        )

    # Handle any PRNG if needed
    rngs = {}
    if dropout_rng is not None:
        rngs["dropout"] = dropout_rng

    inputs = {"params": params or self.params}

    # If past_key_values are passed then the cache is already initialized; a private flag `init_cache` has to be
    # passed down to ensure the cache is used. It also has to be made sure that the cache is marked as mutable
    # so that it can be changed by the FlaxBartAttention module.
    if past_key_values:
        inputs["cache"] = past_key_values
        mutable = ["cache"]
    else:
        mutable = False

    def _decoder_forward(
        module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs
    ):
        projection_module = module._get_projection_module()
        decoder_module = module._get_decoder_module()

        # optionally project encoder_hidden_states
        if projection_module is not None:
            encoder_hidden_states = projection_module(encoder_hidden_states)

        return decoder_module(
            decoder_input_ids,
            decoder_attention_mask,
            decoder_position_ids,
            encoder_hidden_states,
            **kwargs,
        )

    outputs = self.module.apply(
        inputs,
        decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
        decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
        decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        deterministic=not train,
        rngs=rngs,
        mutable=mutable,
        method=_decoder_forward,
    )

    # add updated cache to model output
    if past_key_values is not None and return_dict:
        outputs, past = outputs
        outputs["past_key_values"] = unfreeze(past["cache"])
        return outputs
    elif past_key_values is not None and not return_dict:
        outputs, past = outputs
        outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

    return outputs
def main_opt(N, l, i0, nn_arq, act_fun, n_epochs, lr, w_decay, rho_g):
    start_time = time.time()

    str_nn_arq = ''
    for item in nn_arq:
        str_nn_arq = str_nn_arq + '_{}'.format(item)

    f_job = 'nn_arq{}_N_{}_i0_{}_l_{}_batch'.format(str_nn_arq, N, i0, l)
    f_out = '{}/out_opt_{}.txt'.format(r_dir, f_job)
    f_w_nn = '{}/W_{}.npy'.format(r_dir, f_job)
    file_results = '{}/data_nh3_{}.npy'.format(r_dir, f_job)

    # --------------------------------------
    # Data
    n_atoms = 4
    batch_size = 768  # 1024#768#512#256#128#64#32

    Dtr, Dt = load_data(file_results, N, l)
    Xtr, gXtr, gXctr, ytr = Dtr
    Xt, gXt, gXct, yt = Dt
    print(gXtr.shape, gXtr.shape, gXctr.shape, ytr.shape)

    # --------------------------------
    # BATCHES
    n_complete_batches, leftover = divmod(N, batch_size)
    n_batches = n_complete_batches + bool(leftover)

    def data_stream():
        rng = onpr.RandomState(0)
        while True:
            perm = rng.permutation(N)
            for i in range(n_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield Xtr[batch_idx], gXtr[batch_idx], gXctr[batch_idx], ytr[batch_idx]

    batches = data_stream()

    # --------------------------------
    f = open(f_out, 'a+')
    print('-----------------------------------', file=f)
    print('Starting time', file=f)
    print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M"), file=f)
    print('-----------------------------------', file=f)
    print(f_out, file=f)
    print('N = {}, n_atoms = {}, data_random = {}, NN_random = {}'.format(N, n_atoms, l, i0), file=f)
    print(nn_arq, file=f)
    print('lr = {}, w decay = {}'.format(lr, w_decay), file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('N Epoch = {}'.format(n_epochs), file=f)
    print('rho G = {}'.format(rho_g), file=f)
    print('-----------------------------------', file=f)
    f.close()

    # --------------------------------------
    # initialize NN
    nn_arq.append(3)
    tuple_nn_arq = tuple(nn_arq)
    nn_model = NN_adiab(n_atoms, tuple_nn_arq)

    def get_init_NN_params(key):
        x = Xtr[0, :]
        x = x[None, :]
        # x = jnp.ones((1, Xtr.shape[1]))
        variables = nn_model.init(key, x)
        return variables

    # Initialize parameters
    rng = random.PRNGKey(i0)
    rng, subkey = jax.random.split(rng)
    params = get_init_NN_params(subkey)

    f = open(f_out, 'a+')
    if os.path.isfile(f_w_nn):
        print('Reading NN parameters from prev calculation!', file=f)
        print('-----------------------', file=f)
        nn_dic = jnp.load(f_w_nn, allow_pickle=True)
        params = unfreeze(params)
        params['params'] = nn_dic.item()['params']
        params = freeze(params)
    f.close()

    init_params = params

    # --------------------------------------
    # Phys functions
    @jit
    def nn_adiab(params, x):
        y_ad_pred = nn_model.apply(params, x)
        return y_ad_pred

    @jit
    def jac_nn_adiab(params, x):
        g_y_pred = jacrev(nn_adiab, argnums=1)(params, x[None, :])
        return jnp.reshape(g_y_pred, (2, g_y_pred.shape[-1]))

    '''
    # WRONG
    @jit
    def f_nac_coup_i(gH_diab, eigvect_):  # for a single cartesian dimension
        temp = jnp.dot(gH_diab, eigvect_[:, 0])
        return jnp.vdot(eigvect_[:, 1], temp)

    @jit
    def f_nac_coup(params, x):
        eigval_, eigvect_ = f_adiab(params, x)
        gy_diab = jac_nn_diab(params, x)
        gy_diab = jnp.reshape(gy_diab.T, (12, 2, 2))
        g_coup = vmap(f_nac_coup_i, (0, None))(gy_diab, eigvect_)
        return g_coup
    '''

    # --------------------------------------
    # Validation loss functions
    @jit
    def f_validation(params):
        y_pred = nn_adiab(params, Xt)
        diff_y = y_pred - yt
        z = jnp.linalg.norm(diff_y)
        return z

    @jit
    def f_jac_validation(params):
        gX_pred = vmap(jac_nn_adiab, (None, 0))(params, Xt)
        diff_y = gX_pred - gXt
        z = jnp.linalg.norm(diff_y)
        return z

    '''
    @jit
    def f_nac_validation(params):
        g_nac_coup = vmap(f_nac_coup, (None, 0))(params, Xt)
        diff_y = g_nac_coup - gXct
        z = jnp.linalg.norm(diff_y)
        return z
    '''

    # --------------------------------------
    # training loss functions
    @jit
    def f_loss_ad_energy(params, batch):
        X_inputs, _, _, y_true = batch
        y_pred = nn_adiab(params, X_inputs)
        diff_y = y_pred - y_true  # Ha2cm*
        loss = jnp.linalg.norm(diff_y)
        return loss

    @jit
    def f_loss_jac(params, batch):
        X_inputs, gX_inputs, _, y_true = batch
        gX_pred = vmap(jac_nn_adiab, (None, 0))(params, X_inputs)
        diff_g_X = gX_pred - gX_inputs
        return jnp.linalg.norm(diff_g_X)

    '''
    @jit
    def f_loss_nac(params, batch):
        X_inputs, _, gXc_inputs, y_true = batch
        g_nac_coup = vmap(f_nac_coup, (None, 0))(params, x)
        diff_y = g_nac_coup - gXc_inputs
        z = jnp.linalg.norm(diff_y)
        return z
    '''

    # ------
    @jit
    def f_loss(params, batch):
        loss_ad_energy = f_loss_ad_energy(params, batch)
        # loss_jac_energy = f_loss_jac(params, batch)
        loss = loss_ad_energy  # + rho_g*loss_jac_energy
        return loss

    # --------------------------------------
    # Optimization and Training
    # Perform a single training step.
    @jit
    def train_step(optimizer, batch):  # , learning_rate_fn, model
        grad_fn = jax.value_and_grad(f_loss)
        loss, grad = grad_fn(optimizer.target, batch)
        optimizer = optimizer.apply_gradient(grad)  # , {"learning_rate": lr}
        return optimizer, loss

    optimizer = optim.Adam(learning_rate=lr, weight_decay=w_decay).create(init_params)
    optimizer = jax.device_put(optimizer)

    loss0 = 1E16
    loss0_tot = 1E16
    itercount = itertools.count()
    f_params = init_params
    for epoch in range(n_epochs):
        for _ in range(n_batches):
            optimizer, loss = train_step(optimizer, next(batches))

        params = optimizer.target
        loss_tot = f_validation(params)

        if epoch % 10 == 0:
            f = open(f_out, 'a+')
            print(epoch, loss, loss_tot, file=f)
            f.close()

        if loss < loss0:
            loss0 = loss
            f = open(f_out, 'a+')
            print(epoch, loss, loss_tot, file=f)
            f.close()

        if loss_tot < loss0_tot:
            loss0_tot = loss_tot
            f_params = params
            dict_output = serialization.to_state_dict(params)
            jnp.save(f_w_nn, dict_output)  # unfreeze()

    f = open(f_out, 'a+')
    print('---------------------------------', file=f)
    print('Training time = %.6f seconds ---' % ((time.time() - start_time)), file=f)
    print('---------------------------------', file=f)
    f.close()

    # --------------------------------------
    # Prediction
    f = open(f_out, 'a+')
    print('Prediction of the entire data set', file=f)
    print('N = {}, n_atoms = {}, random = {}'.format(N, n_atoms, i0), file=f)
    print('NN : {}'.format(nn_arq), file=f)
    print('lr = {}, w decay = {}, rho G = {}'.format(lr, w_decay, rho_g), file=f)
    print('Activation function = {}'.format(act_fun), file=f)
    print('Total points = {}'.format(yt.shape[0]), file=f)

    y_pred = nn_adiab(f_params, Xt)
    gX_pred = vmap(jac_nn_adiab, (None, 0))(f_params, Xt)

    diff_y = y_pred - yt
    rmse_Ha = jnp.linalg.norm(diff_y)
    rmse_cm = jnp.linalg.norm(Ha2cm * diff_y)
    mae_Ha = jnp.linalg.norm(diff_y, ord=1)
    mae_cm = jnp.linalg.norm(Ha2cm * diff_y, ord=1)

    print('RMSE = {} [Ha]'.format(rmse_Ha), file=f)
    print('RMSE(tr) = {} [cm-1]'.format(loss0), file=f)
    print('RMSE = {} [cm-1]'.format(rmse_cm), file=f)
    print('MAE = {} [Ha]'.format(mae_Ha), file=f)
    print('MAE = {} [cm-1]'.format(mae_cm), file=f)

    Dpred = jnp.column_stack((Xt, y_pred))
    data_dic = {
        'Dtr': Dtr,
        'Dpred': Dpred,
        'gXpred': gX_pred,
        'loss_tr': loss0,
        'error_full': rmse_cm,
        'N': N,
        'l': l,
        'i0': i0,
        'rho_g': rho_g,
    }
    jnp.save(file_results, data_dic)

    print('---------------------------------', file=f)
    print('Total time = %.6f seconds ---' % ((time.time() - start_time)), file=f)
    print('---------------------------------', file=f)
    f.close()
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], dtype: jnp.dtype = jnp.float32,
                    *model_args, **kwargs):
    r"""
    Instantiate a pretrained flax model from a pre-trained model configuration.

    The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
    pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
    task.

    The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
    weights are discarded.

    Parameters:
        pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
            Can be either:

                - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
                  Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced
                  under a user or organization name, like ``dbmdz/bert-base-german-cased``.
                - A path to a `directory` containing model weights saved using
                  :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                - A path or url to a `pt index checkpoint file` (e.g., ``./tf_model/model.ckpt.index``). In this
                  case, ``from_pt`` should be set to :obj:`True`.
        model_args (sequence of positional arguments, `optional`):
            All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
        config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`):
            Can be either:

                - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
                - a string or path valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.

            Configuration for the model to use instead of an automatically loaded configuration. Configuration
            can be automatically loaded when:

                - The model is a model provided by the library (loaded with the `model id` string of a pretrained
                  model).
                - The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
                  by supplying the save directory.
                - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
                  configuration JSON file named `config.json` is found in the directory.
        cache_dir (:obj:`Union[str, os.PathLike]`, `optional`):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the
            standard cache should not be used.
        from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Load the model weights from a PyTorch checkpoint save file (see docstring of
            ``pretrained_model_name_or_path`` argument).
        force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding
            the cached versions if they exist.
        resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to delete incompletely received files. Will attempt to resume the download if such a
            file exists.
        proxies (:obj:`Dict[str, str]`, `optional`):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to only look at local files (i.e., do not try to download the model).
        revision (:obj:`str`, `optional`, defaults to :obj:`"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use
            a git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be
            any identifier allowed by git.
        kwargs (remaining dictionary of keyword arguments, `optional`):
            Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
            :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
            automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
                  underlying model's ``__init__`` method (we assume all relevant updates to the configuration
                  have already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
                  initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
                  ``kwargs`` that corresponds to a configuration attribute will be used to override said
                  attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any
                  configuration attribute will be passed to the underlying model's ``__init__`` function.

    Examples::

        >>> from transformers import BertConfig, FlaxBertModel
        >>> # Download model and configuration from huggingface.co and cache.
        >>> model = FlaxBertModel.from_pretrained('bert-base-cased')
        >>> # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
        >>> model = FlaxBertModel.from_pretrained('./test/saved_model/')
        >>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower, for example purposes, not runnable).
        >>> config = BertConfig.from_json_file('./pt_model/config.json')
        >>> model = FlaxBertModel.from_pretrained('./pt_model/pytorch_model.bin', from_pt=True, config=config)
    """
    config = kwargs.pop("config", None)
    cache_dir = kwargs.pop("cache_dir", None)
    from_pt = kwargs.pop("from_pt", False)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", False)
    use_auth_token = kwargs.pop("use_auth_token", None)
    revision = kwargs.pop("revision", None)
    from_pipeline = kwargs.pop("_from_pipeline", None)
    from_auto_class = kwargs.pop("_from_auto", False)

    user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
    if from_pipeline is not None:
        user_agent["using_pipeline"] = from_pipeline

    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    # Load config if we don't provide a configuration
    if not isinstance(config, PretrainedConfig):
        config_path = config if config is not None else pretrained_model_name_or_path
        config, model_kwargs = cls.config_class.from_pretrained(
            config_path,
            *model_args,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            revision=revision,
            _from_auto=from_auto_class,
            _from_pipeline=from_pipeline,
            **kwargs,
        )
    else:
        model_kwargs = kwargs

    # Add the dtype to model_kwargs
    model_kwargs["dtype"] = dtype

    # Load model
    if pretrained_model_name_or_path is not None:
        if os.path.isdir(pretrained_model_name_or_path):
            if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                # Load from a PyTorch checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
                # Load from a Flax checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
            else:
                raise EnvironmentError(
                    f"Error no file named {[FLAX_WEIGHTS_NAME, WEIGHTS_NAME]} found in directory "
                    f"{pretrained_model_name_or_path} or `from_pt` set to False"
                )
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            archive_file = pretrained_model_name_or_path
        else:
            archive_file = hf_bucket_url(
                pretrained_model_name_or_path,
                filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME,
                revision=revision,
            )

        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(
                archive_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
            )
        except EnvironmentError as err:
            logger.error(err)
            msg = (
                f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
            )
            raise EnvironmentError(msg)

        if resolved_archive_file == archive_file:
            logger.info(f"loading weights file {archive_file}")
        else:
            logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
    else:
        resolved_archive_file = None

    # init random models
    model = cls(config, *model_args, **model_kwargs)

    if from_pt:
        state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
    else:
        with open(resolved_archive_file, "rb") as state_f:
            try:
                state = from_bytes(cls, state_f.read())
            except UnpicklingError:
                raise EnvironmentError(f"Unable to convert {archive_file} to Flax deserializable object. ")

    # if model is base model only use model_prefix key
    if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
        state = state[cls.base_model_prefix]

    # flatten dicts
    state = flatten_dict(state)
    random_state = flatten_dict(unfreeze(model.params))

    missing_keys = model.required_params - set(state.keys())
    unexpected_keys = set(state.keys()) - model.required_params

    # add missing keys as random parameters
    for missing_key in missing_keys:
        state[missing_key] = random_state[missing_key]

    # remove unexpected keys to not be saved again
    for unexpected_key in unexpected_keys:
        del state[unexpected_key]

    if len(unexpected_keys) > 0:
        logger.warning(
            f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
            f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
            f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
            f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
            f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
            f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
        )
    else:
        logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")

    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
            f"and are newly initialized: {missing_keys}\n"
            f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )
    else:
        logger.info(
            f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
            f"If your task is similar to the task the model of the checkpoint was trained on, "
            f"you can already use {model.__class__.__name__} for predictions without further training."
        )

    # set correct parameters
    model.params = unflatten_dict(state)

    return model
def __call__(
    self,
    input_ids: jnp.ndarray,
    attention_mask: Optional[jnp.ndarray] = None,
    position_ids: Optional[jnp.ndarray] = None,
    params: dict = None,
    past_key_values: dict = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    dropout_rng: PRNGKey = None,
    deterministic: bool = True,
):
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.return_dict

    if attention_mask is None:
        attention_mask = jnp.ones_like(input_ids)

    if position_ids is None:
        position_ids = (attention_mask.cumsum(axis=1) * attention_mask) - 1

    # Handle any PRNG if needed
    rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

    inputs = {"params": params or self.params}

    # If past_key_values are passed then the cache is already initialized; a private flag `init_cache` has to be
    # passed down to ensure the cache is used. It also has to be made sure that the cache is marked as mutable
    # so that it can be changed by the FlaxOPTAttention module.
    if past_key_values:
        inputs["cache"] = past_key_values
        mutable = ["cache"]
    else:
        mutable = False

    outputs = self.module.apply(
        inputs,
        input_ids=jnp.array(input_ids, dtype="i4"),
        attention_mask=jnp.array(attention_mask, dtype="i4"),
        position_ids=jnp.array(position_ids, dtype="i4"),
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        deterministic=deterministic,
        rngs=rngs,
        mutable=mutable,
    )

    # add updated cache to model output
    if past_key_values is not None and return_dict:
        outputs, past_key_values = outputs
        outputs["past_key_values"] = unfreeze(past_key_values["cache"])
        return outputs
    elif past_key_values is not None and not return_dict:
        outputs, past_key_values = outputs
        outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]

    return outputs
        # (fragment: tail of the multi-head attention module's __call__)
        attn = Attn(
            attn_module=self.attn_module,
            qkv_features=qkv_features // self.num_heads,
            out_features=out_features,
        )

        # evaluate multi-headed-attention.
        y = attn(inputs_q, inputs_kv, bias)
        return y.mean(axis=-2)


# run it.
if __name__ == '__main__':
    inputs = jnp.ones((8, 97, 256))
    rngs = {'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)}
    model = MultiHeadDotProductAttention(
        broadcast_dropout=False,
        qkv_features=256,
        out_features=256,
        attn_module=functools.partial(SoftmaxAttnWDropout, rate=0.1),
        num_heads=8,
        batch_axes=(0,),
    )
    y, params = model.init_with_output(rngs, inputs, inputs)
    print('input shape: ', inputs.shape)
    print('parameter shapes:')
    pprint(jax.tree_map(jnp.shape, unfreeze(params)))
    print('output shape: ', y.shape)
def block_hessians(params, loss_fn, param_partition_fn, batches_gen, rng_key, num_lanczos_steps=20):
    """Computes the loss hessian with respect to subsets of model parameters.

    Subsets are determined by the param_partition_fn, which maps the flattened
    model parameters to a dict mapping keys to subsets of the flattened model
    parameters. For example, if the flattened param tree is

    {('a', 'b'): 1.0, ('a', 'c'): 1.0, ('d', 'b'): 2.0}

    and we partition on the outer key (see partition_tree.outer_key), then the
    output is

    {'a': {('a', 'b'): 1.0, ('a', 'c'): 1.0},
     'd': {('d', 'b'): 2.0}}

    Args:
      params: Replicated pytree of model parameters.
      loss_fn: non-pmapped function with API
        loss_fn(unreplicated_params, unreplicated_batch) -> scalar loss.
      param_partition_fn: Maps a flattened pytree to a partitioned dict as
        described above.
      batches_gen: Should yield replicated batches so we can call
        jax.pmap(loss_fn)(params, batch).
      rng_key: Unreplicated jax PRNGKey.
      num_lanczos_steps: How many lanczos iterates to do.

    Returns:
      A dictionary of results, with a key for each partition of the model
      parameters (as determined by param_partition_fn). The key maps to another
      dict with the following key-value pairs:
        max_eig_hess -> The hessian max eigenvalue with respect to the sub params.
        tridiag_hess -> The tridiagonal matrix output by lanczos.
        param_names -> The flattened parameter names in this partition.
    """
    unrep_params = flax.jax_utils.unreplicate(params)
    flat_dict = flax.traverse_util.flatten_dict(unrep_params)
    sub_param_groups = param_partition_fn(flat_dict)
    sub_results = {}

    # I believe this will basically store a copy of the unreplicated
    # parameters in the function definition, which may be costly when
    # len(sub_param_groups) is large? Unclear how garbage collection will
    # handle this in jax.
    def sub_loss(sub_params, batch_rng):
        new_dict = flat_dict.copy()
        for tup in sub_params:
            new_dict[tup] = sub_params[tup]
        new_params = flax.traverse_util.unflatten_dict(new_dict)
        return loss_fn(new_params, batch_rng)

    for key in sub_param_groups:
        logging.info('Block Hessian eval on %s', key)
        sub_params = unfreeze(jax_utils.replicate(sub_param_groups[key]))

        hvp_fn, _, n_params = hessian_computation.get_hvp_fn(sub_loss, sub_params, batches_gen, use_pmap=True)

        # This was needed to avoid the lint [cell-var-from-loop] error. Not sure
        # it's needed, but to avoid any python weirdness with defining functions
        # in a loop went ahead and implemented this.
        hvp_cl = functools.partial(hvp_fn, unfreeze(sub_params))

        row = {}
        row['tridiag_hess'], _ = lanczos.lanczos_np(hvp_cl, n_params, num_lanczos_steps, 0, rng_key, verbose=True)
        evs = np.linalg.eigvalsh(row['tridiag_hess'])
        row['max_eig_hess'] = np.max(evs)
        row['param_names'] = [list(sub_param_groups[key].keys())]

        # The flattened keys are tuples, which doesn't work with the
        # serialization, so to save this dict the keys need to be strings.
        sub_results[str(key)] = row
    return sub_results
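# Hedged usage sketch (not from the source): `outer_key_partition` below is a
# hypothetical `param_partition_fn` that groups flattened params by their
# top-level module name, matching the docstring example above.
def outer_key_partition(flat_dict):
    groups = {}
    for tup, leaf in flat_dict.items():
        groups.setdefault(tup[0], {})[tup] = leaf
    return groups

# results = block_hessians(params, loss_fn, outer_key_partition, batches_gen,
#                          jax.random.PRNGKey(0))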
def main(): # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. # We now keep distinct sets of args, for a cleaner separation of concerns. parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # If we pass only one argument to the script and it's the path to a json file, # let's parse it to get our arguments. model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() if ( os.path.exists(training_args.output_dir) and os.listdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir ): raise ValueError( f"Output directory ({training_args.output_dir}) already exists and is not empty." "Use --overwrite_output_dir to overcome." ) # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) # Setup logging, we only want one process per machine to log things on the screen. logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) if jax.process_index() == 0: datasets.utils.logging.set_verbosity_warning() transformers.utils.logging.set_verbosity_info() else: datasets.utils.logging.set_verbosity_error() transformers.utils.logging.set_verbosity_error() # Set the verbosity to info of the Transformers logger (on main process only): logger.info(f"Training/evaluation parameters {training_args}") # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # (the dataset will be downloaded automatically from the datasets Hub). # # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called # 'text' is found. You can easily tweak this behavior (see below). if data_args.dataset_name is not None: # Downloading and loading a dataset from the hub. dataset = load_dataset( data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False ) if "validation" not in dataset.keys(): dataset["validation"] = load_dataset( data_args.dataset_name, data_args.dataset_config_name, split=f"train[:{data_args.validation_split_percentage}%]", cache_dir=model_args.cache_dir, ) dataset["train"] = load_dataset( data_args.dataset_name, data_args.dataset_config_name, split=f"train[{data_args.validation_split_percentage}%:]", cache_dir=model_args.cache_dir, ) else: data_files = {} if data_args.train_file is not None: data_files["train"] = data_args.train_file if data_args.validation_file is not None: data_files["validation"] = data_args.validation_file extension = data_args.train_file.split(".")[-1] if extension == "txt": extension = "text" dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir) # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at # https://huggingface.co/docs/datasets/loading_datasets.html. 
# Load pretrained config and tokenizer if model_args.config_name: config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir) elif model_args.model_name_or_path: config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir) else: config = CONFIG_MAPPING[model_args.model_type]() logger.warning("You are instantiating a new config instance from scratch.") if model_args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained( model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer ) elif model_args.model_name_or_path: tokenizer = AutoTokenizer.from_pretrained( model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer ) else: raise ValueError( "You are instantiating a new tokenizer from scratch. This is not supported by this script." "You can do it from another script, save it, and load it from here, using --tokenizer_name." ) if training_args.do_train: column_names = dataset["train"].column_names else: column_names = dataset["validation"].column_names text_column_name = "text" if "text" in column_names else column_names[0] # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base") def tokenize_function(examples): with CaptureLogger(tok_logger) as cl: output = tokenizer(examples[text_column_name]) # clm input could be much much longer than block_size if "Token indices sequence length is longer than the" in cl.out: tok_logger.warning( "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model." ) return output tokenized_datasets = dataset.map( tokenize_function, batched=True, num_proc=data_args.preprocessing_num_workers, remove_columns=column_names, load_from_cache_file=not data_args.overwrite_cache, ) if data_args.block_size is None: block_size = tokenizer.model_max_length if block_size > config.max_position_embeddings: logger.warning( f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " "Picking 1024 instead. You can change that default value by passing --block_size xxx." ) block_size = 1024 else: if data_args.block_size > tokenizer.model_max_length: logger.warning( f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model" f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." ) block_size = min(data_args.block_size, tokenizer.model_max_length) # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. def group_texts(examples): # Concatenate all texts. concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can # customize this part to your needs. if total_length >= block_size: total_length = (total_length // block_size) * block_size # Split by chunks of max_len. 
result = { k: [t[i : i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items() } result["labels"] = result["input_ids"].copy() return result # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder # for each of those groups of 1,000 texts. You can adjust that batch_size here, but a higher value might be slower # to preprocess. # # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map lm_datasets = tokenized_datasets.map( group_texts, batched=True, num_proc=data_args.preprocessing_num_workers, load_from_cache_file=not data_args.overwrite_cache, ) if training_args.do_train: if "train" not in tokenized_datasets: raise ValueError("--do_train requires a train dataset") train_dataset = lm_datasets["train"] if data_args.max_train_samples is not None: max_train_samples = min(len(train_dataset), data_args.max_train_samples) train_dataset = train_dataset.select(range(max_train_samples)) if training_args.do_eval: if "validation" not in tokenized_datasets: raise ValueError("--do_eval requires a validation dataset") eval_dataset = lm_datasets["validation"] if data_args.max_eval_samples is not None: max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) eval_dataset = eval_dataset.select(range(max_eval_samples)) # Enable tensorboard only on the master node has_tensorboard = is_tensorboard_available() if has_tensorboard and jax.process_index() == 0: try: from flax.metrics.tensorboard import SummaryWriter summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) except ImportError as ie: has_tensorboard = False logger.warning( f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}" ) else: logger.warning( "Unable to display metrics through TensorBoard because the package is not installed: " "Please run pip install tensorboard to enable." 
) # Initialize our training rng = jax.random.PRNGKey(training_args.seed) rng, dropout_rng = jax.random.split(rng) # Store some constants num_epochs = int(training_args.num_train_epochs) train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() steps_per_epoch = len(train_dataset) // train_batch_size total_train_steps = steps_per_epoch * num_epochs # TODO: weights should be initialized in a pjitted fun; this won't work for REALLY large models # TODO: when loading from a pre-trained model we need to make sure the vocab size is divisible by num_partitions # GPT2's vocab is odd, so we need to resize it for fine-tuning model = FlaxAutoModelForCausalLM.from_pretrained( model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype) ) # Create learning rate schedule linear_decay_lr_schedule_fn = create_learning_rate_fn( len(train_dataset), train_batch_size, training_args.num_train_epochs, training_args.warmup_steps, training_args.learning_rate, ) optimizer = optax.adamw( learning_rate=linear_decay_lr_schedule_fn, b1=training_args.adam_beta1, b2=training_args.adam_beta2, eps=training_args.adam_epsilon, weight_decay=training_args.weight_decay, ) def get_initial_state(params): state = optimizer.init(params) return tuple(state), params # Get PartitionSpec for model params param_spec = set_partitions(unfreeze(model.params)) # Get the PyTree for opt_state; we don't actually initialize the opt_state yet. params_shapes = jax.tree_map(lambda x: x.shape, model.params) state_shapes = jax.eval_shape(get_initial_state, params_shapes) # get PartitionSpec for opt_state; this is very specific to adamw # TODO: optax returns a different state for different optimizers, how can we handle this generically? 
# or maybe we don't, since in our examples we just use adamw or adafactor def get_opt_spec(x): if isinstance(x, dict): return param_spec return None opt_state_spec, param_spec = jax.tree_map( get_opt_spec, state_shapes, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)) ) # pjit the get_initial_state function to shard params and init # optimizer state in a sharded way p_get_initial_state = pjit( get_initial_state, in_axis_resources=None, out_axis_resources=(opt_state_spec, param_spec), ) # hack: move the initial params to CPU to free up device memory # TODO: allow loading weights on CPU in pre-trained model model.params = jax.tree_map(lambda x: np.asarray(x), model.params) # mesh definition mesh_devices = np.array(jax.devices()).reshape(1, jax.local_device_count()) # actually initialize the opt_state with mesh(mesh_devices, ("dp", "mp")): opt_state, params = p_get_initial_state(freeze(model.params)) # cross-entropy with z loss def loss_fn(logits, labels, z_loss=0): shift_logits = logits[..., :-1, :] shift_labels = labels[..., 1:] shift_labels = onehot(shift_labels, shift_logits.shape[-1]) shift_logits = shift_logits - jax.lax.stop_gradient(shift_logits.max(axis=-1, keepdims=True)) log_z = jnp.log(jnp.sum(jnp.exp(shift_logits), axis=-1, keepdims=True)) log_softmax = shift_logits - log_z loss = -jnp.sum(shift_labels * log_softmax, axis=-1) loss += (1e-4 * jnp.square(log_z.squeeze(-1))) * z_loss return loss.mean() # Define gradient update step fn # TODO: try to use TrainState instead of passing params and opt_state individually def train_step(params, opt_state, dropout_rng, batch, step): dropout_rng, new_dropout_rng = jax.random.split(dropout_rng) def compute_loss(params): labels = batch.pop("labels") logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] loss = loss_fn(logits, labels, z_loss=1.0) return loss grad_fn = jax.value_and_grad(compute_loss) loss, grads = grad_fn(params) updates, new_opt_state = optimizer.update(grads, opt_state, params) new_params = optax.apply_updates(params, updates) metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(step)} return new_params, tuple(new_opt_state), new_dropout_rng, metrics, step + 1 # Define eval fn def eval_step(input_ids, labels, params): logits = model(input_ids=input_ids, params=params, train=False)[0] loss = loss_fn(logits, labels) # metrics return {"loss": loss} p_train_step = pjit( train_step, in_axis_resources=(param_spec, opt_state_spec, None, None, None), out_axis_resources=(param_spec, opt_state_spec, None, None, None), donate_argnums=(0, 1), ) p_eval_step = pjit( eval_step, in_axis_resources=(None, None, param_spec), out_axis_resources=None, ) logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num Epochs = {num_epochs}") logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}") logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}") logger.info(f" Total optimization steps = {total_train_steps}") train_time = 0 train_metrics = [] epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) global_step = 0 # we are not doing 2D parallelism (yet!); this just does model parallelism with mesh(mesh_devices, ("dp", "mp")): for _ in epochs: # ======================== Training ================================ train_start = time.time() # Create sampling rng rng, input_rng = jax.random.split(rng) # Generate an epoch by shuffling sampling indices from the train dataset train_metrics = [] train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True) steps_per_epoch = len(train_dataset) // train_batch_size # train for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): batch = next(train_loader) params, opt_state, dropout_rng, train_metric, global_step = p_train_step( params, opt_state, dropout_rng, batch, global_step, ) train_metrics.append(train_metric) cur_step = global_step if cur_step % training_args.logging_steps == 0 and cur_step > 0: # Save metrics train_time += time.time() - train_start if has_tensorboard and jax.process_index() == 0: write_train_metric(summary_writer, train_metrics, train_time, cur_step) epochs.write( f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})" ) train_metrics = [] if cur_step % training_args.eval_steps == 0 and cur_step > 0: # ======================== Evaluating ============================== eval_metrics = [] eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size) eval_steps = len(eval_dataset) // eval_batch_size for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False): batch = next(eval_loader) metrics = p_eval_step(batch["input_ids"], batch["labels"], params) eval_metrics.append(metrics) # normalize eval metrics eval_metrics = stack_forest(eval_metrics) eval_metrics = jax.tree_map(jnp.mean, eval_metrics) try: eval_metrics["perplexity"] = math.exp(eval_metrics["loss"]) except OverflowError: eval_metrics["perplexity"] = float("inf") logger.info( f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})" ) if cur_step % training_args.save_steps == 0 and cur_step > 0: # save a checkpoint every `save_steps` and push it to the hub if jax.process_index() == 0: params = jax.device_get(params) model.save_pretrained( training_args.output_dir, params=params, push_to_hub=training_args.push_to_hub, commit_message=f"Saving weights and logs of step {cur_step}", )
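The training loop above calls two helpers, `data_loader` and `create_learning_rate_fn`, that are not defined in this excerpt. A minimal sketch of what they could look like, modeled on the Flax causal-LM example scripts (the exact names and signatures here are assumptions):

import jax
import jax.numpy as jnp
import numpy as np
import optax

def data_loader(rng, dataset, batch_size, shuffle=False):
    # Yields model-ready dicts of jnp arrays; incomplete final batches are
    # dropped, matching the `len(dataset) // batch_size` step counts above.
    steps_per_epoch = len(dataset) // batch_size
    if shuffle:
        perm = np.asarray(jax.random.permutation(rng, len(dataset)))
    else:
        perm = np.arange(len(dataset))
    perm = perm[: steps_per_epoch * batch_size].reshape((steps_per_epoch, batch_size))
    for idx in perm:
        batch = dataset[idx.tolist()]  # datasets.Dataset returns a dict of lists
        yield {k: jnp.array(v) for k, v in batch.items()}

def create_learning_rate_fn(train_ds_size, train_batch_size, num_train_epochs, num_warmup_steps, learning_rate):
    # Linear warmup from 0 to `learning_rate`, then linear decay back to 0.
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0.0, transition_steps=num_train_steps - num_warmup_steps
    )
    return optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])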
def decode( self, decoder_input_ids, encoder_outputs, decoder_attention_mask: Optional[jnp.ndarray] = None, decoder_position_ids: Optional[jnp.ndarray] = None, past_key_values: dict = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, train: bool = False, params: dict = None, dropout_rng: PRNGKey = None, ): r""" Returns: Example: ```python >>> from transformers import ViTFeatureExtractor, FlaxVisionEncoderDecoderModel >>> import jax.numpy as jnp >>> from PIL import Image >>> import requests >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k") >>> # initialize a vit-gpt2 from pretrained ViT and GPT2 models. Note that the cross-attention layers will be randomly initialized >>> model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained( ... "google/vit-base-patch16-224-in21k", "gpt2" ... ) >>> pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values >>> encoder_outputs = model.encode(pixel_values) >>> decoder_start_token_id = model.config.decoder.bos_token_id >>> decoder_input_ids = jnp.ones((pixel_values.shape[0], 1), dtype="i4") * decoder_start_token_id >>> outputs = model.decode(decoder_input_ids, encoder_outputs) >>> logits = outputs.logits ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) return_dict = return_dict if return_dict is not None else self.config.return_dict encoder_hidden_states = encoder_outputs[0] batch_size, sequence_length = encoder_hidden_states.shape[:2] encoder_attention_mask = jnp.ones((batch_size, sequence_length)) batch_size, sequence_length = decoder_input_ids.shape if decoder_attention_mask is None: decoder_attention_mask = jnp.ones((batch_size, sequence_length)) if decoder_position_ids is None: if past_key_values is not None: raise ValueError( "Make sure to provide `decoder_position_ids` when passing `past_key_values`." ) decoder_position_ids = jnp.broadcast_to( jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) # Handle any PRNG if needed rngs = {} if dropout_rng is not None: rngs["dropout"] = dropout_rng inputs = {"params": params or self.params} # if past_key_values are passed, then the cache is already initialized and a private flag init_cache has to be # passed down to ensure the cache is used. The cache also has to be marked as mutable so that # it can be changed by the FlaxBartAttention module if past_key_values: inputs["cache"] = past_key_values mutable = ["cache"] else: mutable = False def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs): projection_module = module._get_projection_module() decoder_module = module._get_decoder_module() # optionally project encoder_hidden_states if projection_module is not None: encoder_hidden_states = projection_module( encoder_hidden_states) return decoder_module( decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs, ) outputs = self.module.apply( inputs, decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, deterministic=not train, rngs=rngs, mutable=mutable, method=_decoder_forward, ) # add the updated cache to the model output if past_key_values is not None and return_dict: outputs, past = outputs outputs["past_key_values"] = unfreeze(past["cache"]) return outputs elif past_key_values is not None and not return_dict: outputs, past = outputs outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] return outputs
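As a usage note for the cache plumbing above: the typical pattern is to allocate the cache once and then call `decode` one token at a time. A minimal greedy-decoding sketch, continuing from the docstring example (it assumes the model exposes `init_cache(batch_size, max_length, encoder_outputs)`, as the Flax seq2seq models in transformers do, and it glosses over attention-mask details that `model.generate` handles for you):

import jax.numpy as jnp

max_length = 16
batch_size = pixel_values.shape[0]

past_key_values = model.init_cache(batch_size, max_length, encoder_outputs)
sequences = jnp.full((batch_size, 1), decoder_start_token_id, dtype="i4")

for step in range(max_length - 1):
    outputs = model.decode(
        sequences[:, -1:],  # only the newest token; earlier ones live in the cache
        encoder_outputs,
        decoder_position_ids=jnp.full((batch_size, 1), step, dtype="i4"),
        past_key_values=past_key_values,
    )
    past_key_values = outputs.past_key_values  # returned because a cache was passed in
    next_token = jnp.argmax(outputs.logits[:, -1, :], axis=-1)[:, None]
    sequences = jnp.concatenate([sequences, next_token], axis=-1)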
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], dtype: jnp.dtype = jnp.float32, *model_args, **kwargs): r""" Instantiate a pretrained Flax model from a pre-trained model configuration. The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task. The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those weights are discarded. Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`): Can be either: - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`. - A path to a *directory* containing model weights saved using [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - A path or url to a *pt index checkpoint file* (e.g., `./tf_model/model.ckpt.index`). In this case, `from_pt` should be set to `True`. dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and `jax.numpy.bfloat16` (on TPUs). This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If specified, all the computation will be performed with the given `dtype`. **Note that this only specifies the dtype of the computation and does not influence the dtype of model parameters.** If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and [`~FlaxPreTrainedModel.to_bf16`]. model_args (sequence of positional arguments, *optional*): All remaining positional arguments will be passed to the underlying model's `__init__` method. config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*): Can be either: - an instance of a class derived from [`PretrainedConfig`], - a string or path valid as input to [`~PretrainedConfig.from_pretrained`]. Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when: - The model is a model provided by the library (loaded with the *model id* string of a pretrained model). - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the save directory. - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named *config.json* is found in the directory. cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. from_pt (`bool`, *optional*, defaults to `False`): Load the model weights from a PyTorch checkpoint save file (see docstring of `pretrained_model_name_or_path` argument). ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): Whether or not to ignore, rather than raise an error for, weights in the checkpoint that do not have the same size as the corresponding weights of the model (for instance, when you are instantiating a model with 10 labels from a checkpoint with 3 labels). force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. 
resume_download (`bool`, *optional*, defaults to `False`): Whether or not to delete incompletely received files. Will attempt to resume the download if such a file exists. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. local_files_only (`bool`, *optional*, defaults to `False`): Whether or not to only look at local files (i.e., do not try to download the model). revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded: - If a configuration is provided with `config`, `**kwargs` will be directly passed to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done) - If a configuration is not provided, `kwargs` will be first passed to the configuration class initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function. Examples: ```python >>> from transformers import BertConfig, FlaxBertModel >>> # Download model and configuration from huggingface.co and cache. >>> model = FlaxBertModel.from_pretrained("bert-base-cased") >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable). >>> model = FlaxBertModel.from_pretrained("./test/saved_model/") >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable). 
>>> config = BertConfig.from_json_file("./pt_model/config.json") >>> model = FlaxBertModel.from_pretrained("./pt_model/pytorch_model.bin", from_pt=True, config=config) ```""" config = kwargs.pop("config", None) cache_dir = kwargs.pop("cache_dir", None) from_pt = kwargs.pop("from_pt", False) ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", False) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) from_pipeline = kwargs.pop("_from_pipeline", None) from_auto_class = kwargs.pop("_from_auto", False) _do_init = kwargs.pop("_do_init", True) user_agent = { "file_type": "model", "framework": "flax", "from_auto_class": from_auto_class } if from_pipeline is not None: user_agent["using_pipeline"] = from_pipeline if is_offline_mode() and not local_files_only: logger.info("Offline mode: forcing local_files_only=True") local_files_only = True # Load config if we don't provide a configuration if not isinstance(config, PretrainedConfig): config_path = config if config is not None else pretrained_model_name_or_path config, model_kwargs = cls.config_class.from_pretrained( config_path, cache_dir=cache_dir, return_unused_kwargs=True, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, _from_auto=from_auto_class, _from_pipeline=from_pipeline, **kwargs, ) else: model_kwargs = kwargs # Add the dtype to model_kwargs model_kwargs["dtype"] = dtype # Load model if pretrained_model_name_or_path is not None: if os.path.isdir(pretrained_model_name_or_path): if from_pt and os.path.isfile( os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): # Load from a PyTorch checkpoint archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) elif os.path.isfile( os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)): # Load from a Flax checkpoint archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME) # At this stage we don't have a weight file, so we will raise an error. elif os.path.isfile( os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): raise EnvironmentError( f"Error: no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those " "weights.") else: raise EnvironmentError( f"Error: no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory " f"{pretrained_model_name_or_path}.") elif os.path.isfile( pretrained_model_name_or_path) or is_remote_url( pretrained_model_name_or_path): archive_file = pretrained_model_name_or_path else: filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME archive_file = hf_bucket_url( pretrained_model_name_or_path, filename=filename, revision=revision, ) # redirect to the cache, if necessary try: resolved_archive_file = cached_path( archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, ) except RepositoryNotFoundError: raise EnvironmentError( f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " "login` and pass `use_auth_token=True`.") except RevisionNotFoundError: raise EnvironmentError( f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " "this model name. Check the model page at " f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." ) except EntryNotFoundError: if filename == FLAX_WEIGHTS_NAME: has_file_kwargs = { "revision": revision, "proxies": proxies, "use_auth_token": use_auth_token } if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME} " "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from " "those weights.") else: raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME} " f"or {WEIGHTS_NAME}.") else: raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named {filename}." ) except HTTPError as err: raise EnvironmentError( f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n" f"{err}") except ValueError: raise EnvironmentError( f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in the cached " f"files, and it looks like {pretrained_model_name_or_path} is not the path to a directory " f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n" "Check your internet connection, or see how to run the library in offline mode at " "'https://huggingface.co/docs/transformers/installation#offline-mode'." ) except EnvironmentError: raise EnvironmentError( f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}." 
) if resolved_archive_file == archive_file: logger.info(f"loading weights file {archive_file}") else: logger.info( f"loading weights file {archive_file} from cache at {resolved_archive_file}" ) else: resolved_archive_file = None # init random models model = cls(config, *model_args, _do_init=_do_init, **model_kwargs) if from_pt: state = load_pytorch_checkpoint_in_flax_state_dict( model, resolved_archive_file) else: with open(resolved_archive_file, "rb") as state_f: try: state = from_bytes(cls, state_f.read()) except (UnpicklingError, msgpack.exceptions.ExtraData) as e: try: with open(resolved_archive_file) as f: if f.read().startswith("version"): raise OSError( "You seem to have cloned a repository without having git-lfs installed. Please install " "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " "you cloned.") else: raise ValueError from e except (UnicodeDecodeError, ValueError): raise EnvironmentError( f"Unable to convert {archive_file} to a Flax deserializable object. " ) # make sure all arrays are stored as jnp.arrays # NOTE: this prevents a bug that will be fixed in Flax >= v0.3.4: # https://github.com/google/flax/issues/1261 if _do_init: state = jax.tree_util.tree_map(jnp.array, state) else: # keep the params on CPU if we don't want to initialize state = jax.tree_util.tree_map( lambda x: jax.device_put(x, jax.devices("cpu")[0]), state) # if model is a base model, only use the model_prefix key if cls.base_model_prefix not in dict( model.params_shape_tree) and cls.base_model_prefix in state: state = state[cls.base_model_prefix] # if model is a head model and we are loading weights from a base model # we initialize a new params dict with the base_model_prefix if cls.base_model_prefix in dict( model.params_shape_tree ) and cls.base_model_prefix not in state: state = {cls.base_model_prefix: state} # flatten dicts state = flatten_dict(state) random_state = flatten_dict( unfreeze(model.params if _do_init else model.params_shape_tree)) missing_keys = model.required_params - set(state.keys()) unexpected_keys = set(state.keys()) - model.required_params if missing_keys and not _do_init: logger.warning( f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. " "Make sure to call model.init_weights to initialize the missing weights." ) cls._missing_keys = missing_keys # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not # matching the weights in the model. mismatched_keys = [] for key in state.keys(): if key in random_state and state[key].shape != random_state[ key].shape: if ignore_mismatched_sizes: mismatched_keys.append( (key, state[key].shape, random_state[key].shape)) state[key] = random_state[key] else: raise ValueError( f"Trying to load the pretrained weight for {key} failed: checkpoint has shape " f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. " "Use `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this " "model.") # add missing keys as random parameters if we are initializing if missing_keys and _do_init: for missing_key in missing_keys: state[missing_key] = random_state[missing_key] # remove unexpected keys so they are not saved again for unexpected_key in unexpected_keys: del state[unexpected_key] if len(unexpected_keys) > 0: logger.warning( f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " f"initializing {model.__class__.__name__}: {unexpected_keys}\n" f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task " f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n" f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect " f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." ) else: logger.info( f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n" ) if len(missing_keys) > 0: logger.warning( f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " f"and are newly initialized: {missing_keys}\n" f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." ) elif len(mismatched_keys) == 0: logger.info( f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n" f"If your task is similar to the task the model of the checkpoint was trained on, " f"you can already use {model.__class__.__name__} for predictions without further training." ) if len(mismatched_keys) > 0: mismatched_warning = "\n".join([ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" for key, shape1, shape2 in mismatched_keys ]) logger.warning( f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n" f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." ) # dictionary of key: dtypes for the model params param_dtypes = jax.tree_map(lambda x: x.dtype, state) # extract keys of parameters not in jnp.float32 fp16_params = [ k for k in param_dtypes if param_dtypes[k] == jnp.float16 ] bf16_params = [ k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16 ] # raise a warning if any of the parameters are not in jnp.float32 if len(fp16_params) > 0: logger.warning( f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from " f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n" "You should probably UPCAST the model weights to float32 if this was not intended. " "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this." ) if len(bf16_params) > 0: logger.warning( f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from " f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n" "You should probably UPCAST the model weights to float32 if this was not intended. " "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this." ) if _do_init: # set correct parameters model.params = unflatten_dict(state) return model else: return model, unflatten_dict(state)
y = data[:, 0] X = sm.add_constant(X) glm_binom = sm.GLM(y, X, family=sm.families.Binomial()) results = glm_binom.fit() mu = jnp.array(results.params) model = LogisticRegressor() init_key, key = split(key) variables = model.init(init_key, X) output = model.apply(variables, X) learning_rate = 1e-3 optimizer = optax.adam(learning_rate) variables = unfreeze(variables) variables['params']['Dense_0']['kernel'] = mu.reshape((-1, 1)) variables = freeze(variables) alpha = 1. nfeatures = tree_map(lambda x: x.shape[0], variables) loglikelihood_fn, logprior_fn = make_fns_for_posterior(model.apply, alpha) lambda_best, avg_lower_bounds = ffvb.vb_gauss_chol(key, loglikelihood_fn, logprior_fn, (X, y), optimizer, variables, lower_triangular=None, num_samples=20, window_size=10,
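The `LogisticRegressor` flax module used above is not defined in this excerpt; a minimal sketch consistent with how it is used (a single layer auto-registered as `Dense_0`, with no separate bias term since the intercept column comes from `sm.add_constant`) could be:

import flax.linen as nn

class LogisticRegressor(nn.Module):
    @nn.compact
    def __call__(self, X):
        # a single Dense layer is auto-named "Dense_0", which is why the GLM
        # coefficients above are written into variables['params']['Dense_0']['kernel']
        return nn.Dense(1, use_bias=False)(X)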
def __call__( self, input_ids, attention_mask=None, position_ids=None, params: dict = None, past_key_values: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) return_dict = return_dict if return_dict is not None else self.config.return_dict batch_size, sequence_length = input_ids.shape if position_ids is None: if past_key_values is not None: raise ValueError( "Make sure to provide `position_ids` when passing `past_key_values`." ) position_ids = jnp.broadcast_to( jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) if attention_mask is None: attention_mask = jnp.ones((batch_size, sequence_length)) # Handle any PRNG if needed rngs = {} if dropout_rng is not None: rngs["dropout"] = dropout_rng inputs = {"params": params or self.params} # if past_key_values are passed, then the cache is already initialized and a private flag init_cache has to be passed down to ensure the cache is used. The cache also has to be marked as mutable so that it can be changed by the FlaxGPTJAttention module if past_key_values: inputs["cache"] = past_key_values mutable = ["cache"] else: mutable = False outputs = self.module.apply( inputs, jnp.array(input_ids, dtype="i4"), jnp.array(attention_mask, dtype="i4"), jnp.array(position_ids, dtype="i4"), not train, False, output_attentions, output_hidden_states, return_dict, rngs=rngs, mutable=mutable, ) # add the updated cache to the model output if past_key_values is not None and return_dict: outputs, past_key_values = outputs outputs["past_key_values"] = unfreeze(past_key_values["cache"]) return outputs elif past_key_values is not None and not return_dict: outputs, past_key_values = outputs outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] return outputs
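Because this `__call__` refuses `past_key_values` without explicit `position_ids`, callers doing incremental decoding must track absolute positions themselves. A hedged sketch of that pattern (it assumes `model` is a Flax causal LM exposing `init_cache(batch_size, max_length)`, as the transformers Flax GPT models do, and that `input_ids` holds a prompt; real generation should go through `model.generate`, which also maintains the attention mask correctly):

import jax.numpy as jnp

batch_size, prompt_len = input_ids.shape
max_length = prompt_len + 8

# allocate the cache, then prime it on the whole prompt with explicit positions
past_key_values = model.init_cache(batch_size, max_length)
position_ids = jnp.broadcast_to(jnp.arange(prompt_len)[None, :], (batch_size, prompt_len))
outputs = model(input_ids, position_ids=position_ids, past_key_values=past_key_values)
past_key_values = outputs.past_key_values

# afterwards, feed one token per step at its absolute position
for step in range(prompt_len, max_length):
    next_token = jnp.argmax(outputs.logits[:, -1, :], axis=-1)[:, None]
    position_ids = jnp.full((batch_size, 1), step, dtype="i4")
    outputs = model(next_token, position_ids=position_ids, past_key_values=past_key_values)
    past_key_values = outputs.past_key_values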
def __call__( self, input_ids, attention_mask=None, position_ids=None, params: dict = None, past_key_values: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) return_dict = return_dict if return_dict is not None else self.config.return_dict batch_size, sequence_length = input_ids.shape if position_ids is None: if past_key_values is not None and input_ids.shape[-1] == 1: # if `past_key_values` are passed and `input_ids` has sequence length 1, we are doing cached auto-regressive generation, so the position ids have to point at the current cache index cache_shift = flatten_dict( unfreeze(past_key_values))[self._attn_layer_name + ("cache_index", )] position_ids = jnp.broadcast_to( jnp.arange(self.config.max_position_embeddings)[None, :], (batch_size, self.config.max_position_embeddings), ) position_ids = lax.dynamic_slice(position_ids, (0, cache_shift), (batch_size, 1)) else: position_ids = jnp.broadcast_to( jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) if attention_mask is None: # if past_key_values are passed, we need to create an attention_mask of the same length as `cache_length` if past_key_values is not None: cache_length = flatten_dict( unfreeze(past_key_values))[self._attn_layer_name + ("cached_key", )].shape[1] else: cache_length = sequence_length # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. But since GPT2 uses a causal mask, those positions are masked anyway. Thus we can create a single static attention_mask here, which is more efficient for compilation attention_mask = jnp.ones((batch_size, cache_length)) # Handle any PRNG if needed rngs = {} if dropout_rng is not None: rngs["dropout"] = dropout_rng inputs = {"params": params or self.params} # if past_key_values are passed, then the cache is already initialized and a private flag init_cache has to be passed down to ensure the cache is used. The cache also has to be marked as mutable so that it can be changed by the FlaxGPT2Attention module if past_key_values: inputs["cache"] = past_key_values mutable = ["cache"] else: mutable = False outputs = self.module.apply( inputs, jnp.array(input_ids, dtype="i4"), jnp.array(attention_mask, dtype="i4"), jnp.array(position_ids, dtype="i4"), not train, False, output_attentions, output_hidden_states, return_dict, rngs=rngs, mutable=mutable, ) # add the updated cache to the model output if past_key_values is not None and return_dict: outputs, past_key_values = outputs outputs["past_key_values"] = unfreeze(past_key_values["cache"]) return outputs elif past_key_values is not None and not return_dict: outputs, past_key_values = outputs outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] return outputs
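Finally, the `lax.dynamic_slice` trick above just cuts a (batch_size, 1) window out of a full table of position ids, starting at the current cache index; a tiny standalone demo of the same indexing:

import jax.numpy as jnp
from jax import lax

batch_size, max_positions = 2, 8
cache_index = jnp.int32(3)  # stands in for the cache_index stored in past_key_values

position_ids = jnp.broadcast_to(jnp.arange(max_positions)[None, :], (batch_size, max_positions))
current = lax.dynamic_slice(position_ids, (0, cache_index), (batch_size, 1))
print(current)  # [[3] [3]]: the position of the single new token being decoded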