def test_load_with_mismatched_shapes(self):
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        if model_class not in get_values(FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
            continue

        with self.subTest(msg=f"Testing {model_class}"):
            with tempfile.TemporaryDirectory() as tmp_dir:
                model = model_class(config)
                model.save_pretrained(tmp_dir)

                # Fails when we don't set ignore_mismatched_sizes=True
                with self.assertRaises(ValueError):
                    new_model = FlaxAutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)

                logger = logging.get_logger("transformers.modeling_flax_utils")
                with CaptureLogger(logger) as cl:
                    new_model = FlaxAutoModelForSequenceClassification.from_pretrained(
                        tmp_dir, num_labels=42, ignore_mismatched_sizes=True
                    )
                self.assertIn("the shapes did not match", cl.out)

                logits = new_model(**inputs_dict)["logits"]
                self.assertEqual(logits.shape[1], 42)

def main():
    args = parse_args()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Set up logging; we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Handle the repository creation
    if args.push_to_hub:
        if args.hub_model_id is None:
            repo_name = get_full_repo_name(Path(args.output_dir).absolute().name, token=args.hub_token)
        else:
            repo_name = args.hub_model_id
        repo = Repository(args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).

    # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
    # sentences in columns called 'sentence1' and 'sentence2' if such columns exist, or the first two columns
    # not named 'label' if at least two columns are provided.

    # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
    # single column. You can easily tweak this behavior (see below).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.task_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset("glue", args.task_name)
    else:
        # Loading the dataset from local csv or json file.
        data_files = {}
        if args.train_file is not None:
            data_files["train"] = args.train_file
        if args.validation_file is not None:
            data_files["validation"] = args.validation_file
        extension = (args.train_file if args.train_file is not None else args.validation_file).split(".")[-1]
        raw_datasets = load_dataset(extension, data_files=data_files)
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Labels
    if args.task_name is not None:
        is_regression = args.task_name == "stsb"
        if not is_regression:
            label_list = raw_datasets["train"].features["label"].names
            num_labels = len(label_list)
        else:
            num_labels = 1
    else:
        # Trying to have good defaults here, don't hesitate to tweak to your needs.
        is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
        if is_regression:
            num_labels = 1
        else:
            # A useful fast method:
            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
            label_list = raw_datasets["train"].unique("label")
            label_list.sort()  # Let's sort it for determinism
            num_labels = len(label_list)

    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
    model = FlaxAutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, config=config)

    # Preprocessing the datasets
    if args.task_name is not None:
        sentence1_key, sentence2_key = task_to_keys[args.task_name]
    else:
        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
        non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
            sentence1_key, sentence2_key = "sentence1", "sentence2"
        else:
            if len(non_label_column_names) >= 2:
                sentence1_key, sentence2_key = non_label_column_names[:2]
            else:
                sentence1_key, sentence2_key = non_label_column_names[0], None

    # Some models have set the order of the labels to use, so let's make sure we do use it.
    label_to_id = None
    if (
        model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
        and args.task_name is not None
        and not is_regression
    ):
        # Some have all caps in their config, some don't.
        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
        if sorted(label_name_to_id.keys()) == sorted(label_list):
            logger.info(
                f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
                "Using it!"
            )
            label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
        else:
            logger.warning(
                "Your model seems to have been trained with labels, but they don't match the dataset: "
                f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
                "\nIgnoring the model labels as a result."
            )
    elif args.task_name is None:
        label_to_id = {v: i for i, v in enumerate(label_list)}

    def preprocess_function(examples):
        # Tokenize the texts
        texts = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*texts, padding="max_length", max_length=args.max_length, truncation=True)

        if "label" in examples:
            if label_to_id is not None:
                # Map labels to IDs (not necessary for GLUE tasks)
                result["labels"] = [label_to_id[l] for l in examples["label"]]
            else:
                # In all cases, rename the column to labels because the model will expect that.
result["labels"] = examples["label"] return result processed_datasets = raw_datasets.map( preprocess_function, batched=True, remove_columns=raw_datasets["train"].column_names ) train_dataset = processed_datasets["train"] eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"] # Log a few random samples from the training set: for index in random.sample(range(len(train_dataset)), 3): logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") # Define a summary writer has_tensorboard = is_tensorboard_available() if has_tensorboard and jax.process_index() == 0: try: from flax.metrics.tensorboard import SummaryWriter summary_writer = SummaryWriter(args.output_dir) summary_writer.hparams(vars(args)) except ImportError as ie: has_tensorboard = False logger.warning( f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" ) else: logger.warning( "Unable to display metrics through TensorBoard because the package is not installed: " "Please run pip install tensorboard to enable." ) def write_metric(train_metrics, eval_metrics, train_time, step): summary_writer.scalar("train_time", train_time, step) train_metrics = get_metrics(train_metrics) for key, vals in train_metrics.items(): tag = f"train_{key}" for i, val in enumerate(vals): summary_writer.scalar(tag, val, step - len(vals) + i + 1) for metric_name, value in eval_metrics.items(): summary_writer.scalar(f"eval_{metric_name}", value, step) num_epochs = int(args.num_train_epochs) rng = jax.random.PRNGKey(args.seed) dropout_rngs = jax.random.split(rng, jax.local_device_count()) train_batch_size = args.per_device_train_batch_size * jax.local_device_count() eval_batch_size = args.per_device_eval_batch_size * jax.local_device_count() learning_rate_fn = create_learning_rate_fn( len(train_dataset), train_batch_size, args.num_train_epochs, args.num_warmup_steps, args.learning_rate ) state = create_train_state( model, learning_rate_fn, is_regression, num_labels=num_labels, weight_decay=args.weight_decay ) # define step functions def train_step( state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey ) -> Tuple[train_state.TrainState, float]: """Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`.""" dropout_rng, new_dropout_rng = jax.random.split(dropout_rng) targets = batch.pop("labels") def loss_fn(params): logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] loss = state.loss_fn(logits, targets) return loss grad_fn = jax.value_and_grad(loss_fn) loss, grad = grad_fn(state.params) grad = jax.lax.pmean(grad, "batch") new_state = state.apply_gradients(grads=grad) metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch") return new_state, metrics, new_dropout_rng p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,)) def eval_step(state, batch): logits = state.apply_fn(**batch, params=state.params, train=False)[0] return state.logits_fn(logits) p_eval_step = jax.pmap(eval_step, axis_name="batch") if args.task_name is not None: metric = load_metric("glue", args.task_name) else: metric = load_metric("accuracy") logger.info(f"===== Starting training ({num_epochs} epochs) =====") train_time = 0 # make sure weights are replicated on each device state = replicate(state) for epoch in range(1, num_epochs + 1): logger.info(f"Epoch {epoch}") logger.info(" Training...") train_start = time.time() 
        train_metrics = []
        rng, input_rng = jax.random.split(rng)

        # train
        for batch in glue_train_data_collator(input_rng, train_dataset, train_batch_size):
            state, metrics, dropout_rngs = p_train_step(state, batch, dropout_rngs)
            train_metrics.append(metrics)
        train_time += time.time() - train_start
        logger.info(f"    Done! Training metrics: {unreplicate(metrics)}")

        logger.info("  Evaluating...")

        # evaluate
        for batch in glue_eval_data_collator(eval_dataset, eval_batch_size):
            labels = batch.pop("labels")
            predictions = p_eval_step(state, batch)
            metric.add_batch(predictions=chain(*predictions), references=chain(*labels))

        # evaluate also on leftover examples (not divisible by batch_size)
        num_leftover_samples = len(eval_dataset) % eval_batch_size

        # make sure leftover batch is evaluated on one device
        if num_leftover_samples > 0 and jax.process_index() == 0:
            # take leftover samples
            batch = eval_dataset[-num_leftover_samples:]
            batch = {k: jnp.array(v) for k, v in batch.items()}

            labels = batch.pop("labels")
            predictions = eval_step(unreplicate(state), batch)
            metric.add_batch(predictions=predictions, references=labels)

        eval_metric = metric.compute()
        logger.info(f"    Done! Eval metrics: {eval_metric}")

        cur_step = epoch * (len(train_dataset) // train_batch_size)

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            write_metric(train_metrics, eval_metric, train_time, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(args.output_dir, params=params)
            tokenizer.save_pretrained(args.output_dir)
            if args.push_to_hub:
                repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)

    # save the eval metrics in json
    if jax.process_index() == 0:
        eval_metric = {f"eval_{metric_name}": value for metric_name, value in eval_metric.items()}
        path = os.path.join(args.output_dir, "eval_results.json")
        with open(path, "w") as f:
            json.dump(eval_metric, f, indent=4, sort_keys=True)