else: logger.warning( "Unable to display metrics through TensorBoard because the package is not installed: " "Please run pip install tensorboard to enable." ) # Initialize our training rng = jax.random.PRNGKey(training_args.seed) dropout_rngs = jax.random.split(rng, jax.local_device_count()) if model_args.model_name_or_path: model = FlaxT5ForConditionalGeneration.from_pretrained( model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype) ) else: model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)) # Data collator # This one will take care of randomly masking the tokens. data_collator = FlaxDataCollatorForT5MLM( tokenizer=tokenizer, noise_density=data_args.mlm_probability, mean_noise_span_length=data_args.mean_noise_span_length, input_length=max_seq_length, target_length=targets_length, pad_token_id=model.config.pad_token_id, decoder_start_token_id=model.config.decoder_start_token_id, ) # Store some constant num_epochs = int(training_args.num_train_epochs)
def main():
    """Pre-train a Flax T5 model with T5-style span-corrupted masked language modeling.

    Arguments are parsed into ``ModelArguments``, ``DataTrainingArguments`` and
    ``TrainingArguments`` either from the command line or, when the only argument is a
    path ending in ``.json``, from that JSON file. The function then:

    * loads the dataset from the Hub (``--dataset_name``) or from local files
      (``--train_file`` / ``--validation_file``), carving a validation split out of
      ``train`` when none exists;
    * tokenizes the text column and concatenates/chunks it into pieces of
      ``expanded_inputs_length`` tokens;
    * trains with a data-parallel ``jax.pmap`` step, logging every ``logging_steps``,
      evaluating every ``eval_steps`` and saving (optionally pushing) every
      ``save_steps``;
    * when ``--do_eval`` is set, writes final metrics to
      ``<output_dir>/eval_results.json`` from process 0.

    Raises:
        ValueError: if the output directory is non-empty and would be overwritten,
            if no tokenizer source is given, or if neither a dataset name nor any
            data file is provided.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_t5_mlm", model_args, data_args, framework="flax")

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        level=logging.INFO,
        datefmt="[%X]",
    )

    # Log on each process the small summary:
    logger = logging.getLogger(__name__)

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
            )
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Every Hub call below uses the same auth setting.
    use_auth_token = True if model_args.use_auth_token else None

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            use_auth_token=use_auth_token,
        )

        if "validation" not in datasets.keys():
            # Carve the first `validation_split_percentage`% of `train` off as validation.
            datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                use_auth_token=use_auth_token,
            )
            datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                use_auth_token=use_auth_token,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file

        # Infer the loader from the training file, falling back to the validation file so
        # that an evaluation-only run with just --validation_file does not crash on None.
        reference_file = data_args.train_file if data_args.train_file is not None else data_args.validation_file
        if reference_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        extension = reference_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        datasets = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            use_auth_token=use_auth_token,
        )

        if "validation" not in datasets.keys():
            datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                use_auth_token=use_auth_token,
            )
            datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                use_auth_token=use_auth_token,
            )
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            use_auth_token=use_auth_token,
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            use_auth_token=use_auth_token,
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.config_name:
        config = T5Config.from_pretrained(
            model_args.config_name,
            cache_dir=model_args.cache_dir,
            vocab_size=len(tokenizer),
            use_auth_token=use_auth_token,
        )
    elif model_args.model_name_or_path:
        config = T5Config.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_auth_token=use_auth_token,
        )
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = datasets["train"].column_names
    else:
        column_names = datasets["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
    # Since we make sure that all sequences are of the same length, no attention_mask is needed.
    def tokenize_function(examples):
        return tokenizer(examples[text_column_name], return_attention_mask=False)

    tokenized_datasets = datasets.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
    # To ensure that the input length is `max_seq_length`, we need to increase the maximum length
    # according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
    expanded_inputs_length, targets_length = compute_input_and_target_lengths(
        inputs_length=max_seq_length,
        noise_density=data_args.mlm_probability,
        mean_noise_span_length=data_args.mean_noise_span_length,
    )

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of
    # expanded_inputs_length.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
        if total_length >= expanded_inputs_length:
            total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
        # Split by chunks of max_len.
        result = {
            k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
            for k, t in concatenated_examples.items()
        }
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
    # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
    # might be slower to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
    tokenized_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if not has_tensorboard:
        # Only warn when the package is really missing; other hosts simply don't log.
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )
    elif jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    if model_args.model_name_or_path:
        model = FlaxT5ForConditionalGeneration.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
            use_auth_token=use_auth_token,
        )
    else:
        config.vocab_size = len(tokenizer)
        model = FlaxT5ForConditionalGeneration(
            config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
        )

    # Data collator
    # This one will take care of randomly masking the tokens.
    data_collator = FlaxDataCollatorForT5MLM(
        tokenizer=tokenizer,
        noise_density=data_args.mlm_probability,
        mean_noise_span_length=data_args.mean_noise_span_length,
        input_length=max_seq_length,
        target_length=targets_length,
        pad_token_id=model.config.pad_token_id,
        decoder_start_token_id=model.config.decoder_start_token_id,
    )

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
    eval_batch_size = per_device_eval_batch_size * jax.device_count()

    num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs

    num_of_hosts = jax.process_count()
    current_host_idx = jax.process_index()

    # Create learning rate schedule: linear warmup to `learning_rate`, then linear decay to 0.
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
    )
    decay_fn = optax.linear_schedule(
        init_value=training_args.learning_rate,
        end_value=0,
        transition_steps=num_train_steps - training_args.warmup_steps,
    )
    linear_decay_lr_schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        # find out all LayerNorm parameters
        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
        layer_norm_named_params = {
            layer[-2:]
            for layer_norm_name in layer_norm_candidates
            for layer in flat_params.keys()
            if layer_norm_name in "".join(layer).lower()
        }
        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
        return traverse_util.unflatten_dict(flat_mask)

    # create adam optimizer
    if training_args.adafactor:
        # We use the default parameters here to initialize adafactor,
        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
        optimizer = optax.adafactor(
            learning_rate=linear_decay_lr_schedule_fn,
        )
    else:
        optimizer = optax.adamw(
            learning_rate=linear_decay_lr_schedule_fn,
            b1=training_args.adam_beta1,
            b2=training_args.adam_beta2,
            weight_decay=training_args.weight_decay,
            mask=decay_mask_fn,
        )

    # Setup train state
    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)

    # Define gradient update step fn
    def train_step(state, batch, dropout_rng):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

        def loss_fn(params):
            labels = batch.pop("labels")

            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]

            # compute loss
            loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()

            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        # Average gradients across devices before applying them.
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)

        metrics = jax.lax.pmean(
            {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
        )

        return new_state, metrics, new_dropout_rng

    # Create parallel version of the train step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))

    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")

        logits = model(**batch, params=params, train=False)[0]

        # compute loss
        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))

        # compute accuracy
        accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels)

        # summarize metrics
        metrics = {"loss": loss.mean(), "accuracy": accuracy.mean()}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return metrics

    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))

    def run_evaluation(params):
        """Run `p_eval_step` over the whole validation split; return `get_metrics` of the per-batch results."""
        num_eval_samples = len(tokenized_datasets["validation"])
        # Avoid using jax.numpy here in case of TPU training
        eval_samples_idx = np.arange(num_eval_samples)
        # Keep the last partial batch; pad_shard_unpad pads it to a device multiple.
        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)

        eval_metrics = []
        for eval_batch in tqdm(eval_batch_idx, desc="Evaluating ...", position=2):
            samples = [tokenized_datasets["validation"][int(idx)] for idx in eval_batch]
            model_inputs = data_collator(samples)

            # Model forward
            metrics = pad_shard_unpad(p_eval_step, static_return=True)(
                params, model_inputs.data, min_device_batch=per_device_eval_batch_size
            )
            eval_metrics.append(metrics)

        return get_metrics(eval_metrics)

    # Replicate the train state on each device
    state = jax_utils.replicate(state)

    # Cumulative wall-clock seconds reported to TensorBoard as "train_time".
    train_time = 0
    epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()
        train_metrics = []

        # Create sampling rng
        # NOTE(review): input_rng is currently unused; the shuffle below uses NumPy's global RNG.
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train dataset
        num_train_samples = len(tokenized_datasets["train"])
        # Avoid using jax.numpy here in case of TPU training
        train_samples_idx = np.random.permutation(np.arange(num_train_samples))
        train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)

        # Gather the indexes for creating the batch and do a training step
        for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
            samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
            model_inputs = data_collator(samples)

            # Each host keeps only its own slice of the global batch.
            local_host_model_inputs = {
                key: np.split(model_inputs.data[key], num_of_hosts, axis=0)[current_host_idx]
                for key in model_inputs.data
            }

            # Model forward
            model_inputs = shard(local_host_model_inputs)
            state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
            train_metrics.append(train_metric)

            cur_step = epoch * (num_train_samples // train_batch_size) + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = jax_utils.unreplicate(train_metric)
                now = time.time()
                train_time += now - train_start
                # Restart the interval so the next logging step only adds the time elapsed
                # since this one; otherwise earlier parts of the epoch are counted again.
                train_start = now
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
                    f" {train_metric['learning_rate'].mean()})"
                )

                train_metrics = []

            if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                # ======================== Evaluating ==============================
                eval_metrics = jax.tree_map(jnp.mean, run_evaluation(state.params))

                # Update progress bar
                epochs.write(f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})")

                # Save metrics
                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if cur_step % training_args.save_steps == 0 and cur_step > 0:
                # save checkpoint after each epoch and push checkpoint to the hub
                if jax.process_index() == 0:
                    # Take device 0's replica of the params back to host memory.
                    params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
                    model.save_pretrained(training_args.output_dir, params=params)
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)

    # Eval after training
    if training_args.do_eval:
        eval_metrics = run_evaluation(state.params)
        eval_metrics = jax.tree_map(lambda metric: jnp.mean(metric).item(), eval_metrics)

        if jax.process_index() == 0:
            eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
            path = os.path.join(training_args.output_dir, "eval_results.json")
            with open(path, "w") as f:
                json.dump(eval_metrics, f, indent=4, sort_keys=True)
def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_folder_path):
    """Copy the weights of a T5X checkpoint into a Flax T5 model and save it.

    Args:
        t5x_checkpoint_path: Path given to ``checkpoints.load_t5x_checkpoint``.
        config_name: Name or path given to ``T5Config.from_pretrained``. Its
            ``num_layers`` and ``num_decoder_layers`` control how many encoder and
            decoder layers are copied.
        flax_dump_folder_path: Directory given to ``save_pretrained`` for the result.

    The following are copied from the T5X ``target`` tree into a newly constructed
    ``FlaxT5ForConditionalGeneration``:

    * attention and MLP kernels;
    * pre-norm scales and final-norm scales;
    * relative-position embeddings, transposed;
    * token embeddings;
    * the LM head, only when ``decoder/logits_dense`` exists in the checkpoint.
    """
    config = T5Config.from_pretrained(config_name)
    flax_model = FlaxT5ForConditionalGeneration(config=config)
    t5x_model = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)

    t5x_target = t5x_model["target"]
    # Some checkpoints split the MLP input projection into wi_0/wi_1 instead of a single
    # wi. This is detected once from encoder layer 0 and applied to every encoder and
    # decoder layer.
    split_mlp_wi = "wi_0" in t5x_target["encoder"]["layers_0"]["mlp"]
    mlp_kernel_names = (("wi_0", "wi_1") if split_mlp_wi else ("wi",)) + ("wo",)
    # T5X attention projection name -> Flax T5 attention projection name.
    attention_name_map = (("key", "k"), ("out", "o"), ("query", "q"), ("value", "v"))

    flax_params = flax_model.params

    def copy_attention(t5x_attention, flax_attention):
        """Copy the k/o/q/v projection kernels of one attention module."""
        for t5x_name, flax_name in attention_name_map:
            flax_attention[flax_name]["kernel"] = t5x_attention[t5x_name]["kernel"]

    def copy_mlp(t5x_mlp, flax_mlp):
        """Copy the input-projection kernel(s) and the output kernel of one MLP."""
        for kernel_name in mlp_kernel_names:
            flax_mlp[kernel_name]["kernel"] = t5x_mlp[kernel_name]["kernel"]

    # ---- Encoder ----
    t5x_encoder = t5x_target["encoder"]
    flax_encoder = flax_params["encoder"]
    for layer_index in range(config.num_layers):
        t5x_layer = t5x_encoder[f"layers_{layer_index}"]
        flax_sublayers = flax_encoder["block"][str(layer_index)]["layer"]

        # Sub-layer "0": self-attention plus its pre-attention norm.
        copy_attention(t5x_layer["attention"], flax_sublayers["0"]["SelfAttention"])
        flax_sublayers["0"]["layer_norm"]["weight"] = t5x_layer["pre_attention_layer_norm"]["scale"]

        # Sub-layer "1": feed-forward plus its pre-MLP norm.
        copy_mlp(t5x_layer["mlp"], flax_sublayers["1"]["DenseReluDense"])
        flax_sublayers["1"]["layer_norm"]["weight"] = t5x_layer["pre_mlp_layer_norm"]["scale"]

    # The relative position bias is assigned to block 0 only, transposed (.T) from the T5X layout.
    flax_encoder["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"]["embedding"] = (
        t5x_encoder["relpos_bias"]["rel_embedding"].T
    )
    flax_encoder["final_layer_norm"]["weight"] = t5x_encoder["encoder_norm"]["scale"]

    # ---- Decoder ----
    t5x_decoder = t5x_target["decoder"]
    flax_decoder = flax_params["decoder"]
    for layer_index in range(config.num_decoder_layers):
        t5x_layer = t5x_decoder[f"layers_{layer_index}"]
        flax_sublayers = flax_decoder["block"][str(layer_index)]["layer"]

        # Sub-layer "0": causal self-attention plus its norm.
        copy_attention(t5x_layer["self_attention"], flax_sublayers["0"]["SelfAttention"])
        flax_sublayers["0"]["layer_norm"]["weight"] = t5x_layer["pre_self_attention_layer_norm"]["scale"]

        # Sub-layer "1": encoder-decoder (cross) attention plus its norm.
        copy_attention(t5x_layer["encoder_decoder_attention"], flax_sublayers["1"]["EncDecAttention"])
        flax_sublayers["1"]["layer_norm"]["weight"] = t5x_layer["pre_cross_attention_layer_norm"]["scale"]

        # Sub-layer "2": feed-forward plus its pre-MLP norm.
        copy_mlp(t5x_layer["mlp"], flax_sublayers["2"]["DenseReluDense"])
        flax_sublayers["2"]["layer_norm"]["weight"] = t5x_layer["pre_mlp_layer_norm"]["scale"]

    flax_decoder["final_layer_norm"]["weight"] = t5x_decoder["decoder_norm"]["scale"]
    flax_decoder["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"]["embedding"] = (
        t5x_decoder["relpos_bias"]["rel_embedding"].T
    )

    # ---- Shared token embeddings ----
    flax_params["shared"]["embedding"] = t5x_target["token_embedder"]["embedding"]

    # A separate LM head is copied only when the checkpoint contains one.
    if "logits_dense" in t5x_decoder:
        flax_params["lm_head"]["kernel"] = t5x_decoder["logits_dense"]["kernel"]

    flax_model.save_pretrained(flax_dump_folder_path)
    print("T5X Model was sucessfully converted!")