def main(_):
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  if not gfile.isdir(FLAGS.save_dir):
    gfile.mkdir(FLAGS.save_dir)

  hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
  # Get hyperparameters.
  if FLAGS.xm_parameters:
    for key, value in json.loads(FLAGS.xm_parameters).items():
      if key not in hparam_str_dict:
        hparam_str_dict[key] = value

  hparam_str = ','.join(['%s=%s' % (k, str(hparam_str_dict[k]))
                         for k in sorted(hparam_str_dict.keys())])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables.
  id_char_table = {i+1: char for (i, char) in enumerate(dsl.CHARACTER)}
  char_id_table = {char: id for id, char in id_char_table.items()}
  id_token_table, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1

  bos_token = token_id_table[dsl.BOS]
  eos_token = token_id_table[dsl.EOS]

  def decode_io(inputs, outputs):
    """Decode io examples tokens."""
    def decode_str(s):
      """Decode string tokens."""
      return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

    io_string = ''
    inps, outs = [], []
    for inp, out in zip(inputs, outputs):
      inps.append(decode_str(inp))
      outs.append(decode_str(out))
      io_string += inps[-1] + ' < ' + outs[-1] + ' > '
    return inps, outs, io_string[:-3]  # Remove last separator.

  def decode_program(program):
    """Decode program tokens."""
    program = program[:np.argmax(program == eos_token) + 1].astype(np.int32)
    try:
      p = dsl.decode_program(program, id_token_table)
      return p, p.to_string()
    except:  # pylint: disable=bare-except
      return None, ''  # Program does not compile.

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.dataset_filepattern, token_id_table, char_id_table)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]),
      drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease batch of predict dataset to handle beam search.
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]))
  train_ds = dataset.skip(FLAGS.num_eval_steps).repeat()
  train_iter = train_ds.as_numpy_iterator()

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  base_train_config = models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  base_eval_config = base_train_config.replace(deterministic=True,
                                               train_vq=False)
  base_predict_config = base_train_config.replace(
      shift=False, deterministic=True, train_vq=False, decode=True)
  train_config = models.LatentTransformerConfig(
      base_cfg=base_train_config,
      latent_vocab_size=FLAGS.latent_vocab_size,
      c=FLAGS.c,
      train_vq=True,
      commitment_cost_vq=FLAGS.commitment_cost_vq)
  eval_config = models.LatentTransformerConfig(
      base_cfg=base_eval_config,
      latent_vocab_size=FLAGS.latent_vocab_size,
      c=FLAGS.c,
      train_vq=True,
      commitment_cost_vq=FLAGS.commitment_cost_vq)
  predict_config = models.LatentTransformerConfig(
      base_cfg=base_predict_config,
      latent_vocab_size=FLAGS.latent_vocab_size,
      c=FLAGS.c,
      train_vq=True,
      commitment_cost_vq=FLAGS.commitment_cost_vq)

  # Latent Predictor.
  lp_train_config = models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=FLAGS.latent_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  lp_eval_config = lp_train_config.replace(deterministic=True)
  lp_predict_config = lp_train_config.replace(
      shift=False, deterministic=True, decode=True)

  rng = jax.random.PRNGKey(0)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  m = models.LatentProgramTransformer(eval_config)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))
  lp_m = models.ProgramTransformer(lp_eval_config)
  lp_initial_variables = jax.jit(lp_m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))

  optimizer_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])
  lp_optimizer = optimizer_def.create(lp_initial_variables['params'])

  state = TrainState(step=0,
                     optimizer=optimizer,
                     model_state=initial_variables['vqvae'],
                     lp_optimizer=lp_optimizer)
  # Don't keep a copy of the initial model.
  del initial_variables, lp_initial_variables

  train_rngs = jax.random.split(rng, jax.local_device_count())

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    state = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), state)
    # Grab last step.
    start_step = int(state.step)
    logging.info('Found model checkpointed at step %d.', start_step)

  state = jax_utils.replicate(state)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=FLAGS.lr)
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          bos_token=bos_token,
          eos_token=eos_token,
          learning_rate_fn=learning_rate_fn,
          config=train_config,
          lp_config=lp_train_config),
      axis_name='batch',
      static_broadcasted_argnums=(4,))
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          bos_token=bos_token,
          eos_token=eos_token,
          config=eval_config,
          lp_config=lp_eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config,
          lp_config=lp_predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          bos_token=bos_token,
          eos_token=eos_token,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config,
          lp_config=lp_predict_config),
      axis_name='batch',
      static_broadcasted_argnums=(5,))

  metrics_all = []
  latent_metrics_all = []
  tick = time.time()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, programs = common_utils.shard(next(train_iter))

    state, metrics, latent_metrics, train_rngs = p_train_step(
        state, inputs, outputs, programs,
        step <= FLAGS.num_pretrain_steps,
        train_rng=train_rngs)
    metrics, latent_metrics = jax.tree_map(np.array,
                                           (metrics, latent_metrics))
    metrics_all.append(metrics)
    latent_metrics_all.append(latent_metrics)

    # Save a Checkpoint
    if ((step % FLAGS.checkpoint_freq == 0 and step > 0)
        or step == FLAGS.num_train_steps - 1):
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(state),
            step)

    # Periodic metric handling.
    if not step or step % FLAGS.log_freq != 0:
      continue

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(
        lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
        metrics_sums)
    summary['learning_rate'] = lr
    # Calculate (clipped) perplexity after averaging log-perplexities:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

    latent_metrics_all = common_utils.get_metrics(latent_metrics_all)
    metrics_sums = jax.tree_map(jnp.sum, latent_metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary.update(jax.tree_map(
        lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
        metrics_sums))

    if jax.host_id() == 0:
      logging.info('Train in step: %d, loss: %.4f, acc: %.4f',
                   step, summary['loss'], summary['accuracy'])
      tock = time.time()
      steps_per_sec = FLAGS.log_freq / (tock - tick)
      tick = tock
      summary_writer.scalar('train/steps per second', steps_per_sec, step)
      for key, val in summary.items():
        summary_writer.scalar('train/' + key, val, step)
      summary_writer.flush()
    # Reset metric accumulation for next evaluation cycle.
    metrics_all = []
    latent_metrics_all = []

    # Evaluation Metrics
    logging.info('Gathering evaluation metrics.')
    t_evaluation_start = time.time()
    eval_metrics = []
    latent_eval_metrics = []
    for batches in eval_ds.as_numpy_iterator():
      inputs, outputs, programs = common_utils.shard(batches)

      all_metrics = p_eval_step(state, inputs, outputs, programs)
      metrics, latent_metrics = jax.tree_map(np.array, all_metrics)
      eval_metrics.append(metrics)
      latent_eval_metrics.append(latent_metrics)

    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)

    latent_eval_metrics = common_utils.get_metrics(latent_eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, latent_eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary.update(jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums))

    if jax.host_id() == 0:
      logging.info('Evaluation time: %.4f s step %d, loss: %.4f',
                   time.time()-t_evaluation_start, step,
                   eval_summary['loss'])
      for key, val in eval_summary.items():
        summary_writer.scalar('eval/' + key, val, step)
      summary_writer.flush()

    # Beam search metrics.
    logging.info('Gathering beam search metrics.')
    for beam_size in [10, 50, 100]:
      t_inference_start = time.time()
      pred_acc = 0
      pred_denominator = 0

      ios, targets, predictions, latent_predictions = [], [], [], []
      for batches in predict_ds.as_numpy_iterator():
        pred_batch = batches
        # Handle final odd-sized batch by padding instead of dropping it.
        cur_pred_batch_size = pred_batch[0].shape[0]
        if cur_pred_batch_size % n_devices:
          padded_size = int(
              np.ceil(cur_pred_batch_size / n_devices) * n_devices)
          pred_batch = jax.tree_map(
              lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
        inputs, outputs, programs = common_utils.shard(pred_batch)

        cache, lp_cache = p_init_cache(inputs, outputs, programs)
        predicted, latent_predicted = p_pred_step(
            state, inputs, outputs, cache, lp_cache, beam_size)
        predicted, latent_predicted = map(
            tohost, (predicted, latent_predicted))
        inputs, outputs, programs = map(tohost, (inputs, outputs, programs))

        pred_denominator += programs.shape[0]
        for i, beams in enumerate(predicted):
          inps, outs, io_string = decode_io(inputs[i], outputs[i])
          p, p_idx, p_score = eval_predicted(
              beams, inps, outs,
              parse_beam_fn=lambda x: decode_program(x)[0])
          if p_score >= len(inps):
            pred_acc += 1
          ios.append(io_string)
          targets.append(decode_program(programs[i])[1])
          predictions.append(p.to_string() if p else '')
          latent_predictions.append(
              ' '.join(list(np.array(latent_predicted[i, p_idx]).astype(str))))

      all_pred_acc, all_pred_denominator = per_host_sum_pmap(
          jax.tree_map(np.array, (pred_acc, pred_denominator)))

      # Record beam search results as text summaries.
      message = []
      for n in np.random.choice(np.arange(len(predictions)), 8):
        text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
                f'predicted: {predictions[n]}\n\n'
                f'latent_predicted: {latent_predictions[n]}\n\n')
        message.append(text)

      # Write to tensorboard.
      if jax.host_id() == 0:
        logging.info(
            'Prediction time (beam %d): %.4f s step %d, score %.4f.',
            beam_size, time.time() - t_inference_start, step,
            all_pred_acc / all_pred_denominator)
        summary_writer.scalar('predict/score-{}'.format(beam_size),
                              all_pred_acc / all_pred_denominator, step)
        summary_writer.text('samples-{}'.format(beam_size),
                            '\n------\n'.join(message), step)
        summary_writer.flush()
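

# The prediction loop above calls `pad_examples` and `tohost`, which are
# defined elsewhere in this file. Minimal sketches are given below, inferred
# from the call sites; treat the exact padding and reshaping behavior as
# assumptions rather than the verified originals.
def pad_examples(x, desired_batch_size):
  """Expands a batch to `desired_batch_size` by repeating the last example."""
  batch_pad = desired_batch_size - x.shape[0]
  tile_dims = [1] * len(x.shape)
  tile_dims[0] = batch_pad
  return np.concatenate([x, np.tile(x[-1], tile_dims)], axis=0)


def tohost(x):
  """Collects batches from all devices to host, flattening the device axis."""
  n_device, n_batch, *remaining_dims = x.shape
  return np.array(x).reshape((n_device * n_batch,) + tuple(remaining_dims))
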
def run_experiment(
    model_dir,
    data_dir=None,
    xid=None,
    batch_size_per_device=128,
    eval_frequency=500,
    checkpoint_frequency=10000,
    save_checkpoints=True,
    restore_checkpoint=True,
    num_eval_steps=None,
    epochs=None,
    max_train_steps=1000000,  # 1 million
    max_train_length=512,
    train_summary_frequency=100,
    max_eval_length=None,
    model_cls=models.FlaxLM):
  """Run experiment.

  Args:
    model_dir: Directory to save checkpoints and metrics to.
    data_dir: Directory to load data from.
    xid: Optional experiment id.
    batch_size_per_device: Batch size per device.
    eval_frequency: Steps per eval.
    checkpoint_frequency: How often to checkpoint. If None, only checkpoint
      once at the end of the run.
    save_checkpoints: If True, checkpoints the model according to
      checkpoint_frequency.
    restore_checkpoint: If True, restores a checkpoint from the directory.
      Useful for robustness to preemption.
    num_eval_steps: Number of eval steps to take on the eval dataset.
    epochs: Number of train epochs.
    max_train_steps: Stop training after N steps.
    max_train_length: Crop training sequences to this length.
    train_summary_frequency: Frequency to write train metrics.
    max_eval_length: Maximum eval length. Defaults to max_train_length.
    model_cls: Model class to use.

  Returns:
    FlaxLM resulting from running training.
  """
  if xid is not None:
    model_dir = os.path.join(model_dir,
                             '%s_l%s' % (str(xid), max_train_length))
  tf.enable_v2_behavior()
  if jax.host_id() == 0:
    summary_writer = tf_summary.create_file_writer(
        os.path.join(model_dir, 'metrics'),
        max_queue=1,
        flush_millis=1000)
    train_summary_writer = logging_lib.ScalarSummary(
        step=None, scope='train/', enable_tf=True, verbose=0)
    eval_summary_writer = logging_lib.ScalarSummary(
        step=None, scope='eval/', enable_tf=True, verbose=0)

  batch_size = batch_size_per_device * jax.local_device_count()
  max_eval_length = max_eval_length or max_train_length
  train_files, test_files = data.get_train_test_files(directory=data_dir)
  train_ds, eval_ds = data.load_dataset(
      train_files=train_files,
      test_files=test_files,
      batch_size=batch_size,
      max_train_length=max_train_length,
      max_eval_length=max_eval_length,
      shuffle_buffer=16384)

  with contextlib.ExitStack() as stack:  # pylint: disable=using-constant-test
    if jax.host_id() == 0:
      # Only need metric writer context manager on host 0.
      stack.enter_context(summary_writer.as_default())
    model = model_cls(domain=data.protein_domain, batch_size=batch_size)

    if restore_checkpoint:
      try:
        model.load_checkpoint(model_dir)
      except ValueError:
        # No checkpoint to load -> raises ValueError.
        pass
    start_step = model.train_step

    train_ds = train_ds.repeat(epochs)
    train_iter = iter(train_ds)
    train_metrics = []
    tick = time.time()

    if jax.host_id() == 0:
      _write_gin_configs(os.path.join(model_dir, 'config.gin'))

    num_evals = 0
    for step, batch in zip(range(start_step, max_train_steps), train_iter):
      batch = jax.tree_map(lambda x: x._numpy(), batch)  # pylint: disable=protected-access
      metrics = model.fit_batch(batch)
      train_metrics.append(metrics)

      if jax.host_id() == 0 and ((save_checkpoints and checkpoint_frequency
                                  and step % checkpoint_frequency == 0
                                  and step > 0)
                                 or step == max_train_steps - 1):
        model.save_checkpoint(model_dir)

      if (step + 1) % train_summary_frequency == 0:
        summary = evaluation.combine_metrics(train_metrics)
        logging.info('train in step: %d, loss: %.4f', step, summary['loss'])
        if jax.host_id() == 0:
          tock = time.time()
          steps_per_sec = eval_frequency / (tock - tick)
          tick = tock
          train_summary_writer('steps per second', steps_per_sec, step)
          for key, val in summary.items():
            if jnp.isnan(val):
              raise ValueError(f'NaN in {key} at step {step}.')
            train_summary_writer(key, val, step)

        # Reset metric accumulation for next evaluation cycle.
        train_metrics = []

      if eval_frequency and (step + 1) % eval_frequency == 0:
        eval_summary = evaluation.evaluate(
            model=model, eval_ds=eval_ds, num_eval_steps=num_eval_steps)
        logging.info('eval in step: %d, loss: %.4f', step,
                     eval_summary['loss'])
        if jax.host_id() == 0:
          for key, val in eval_summary.items():
            eval_summary_writer(key, val, step)
          tf_summary.flush()
          summary_writer.flush()
          if num_evals == 0:
            # Write out config on first eval.
            _write_gin_configs(
                os.path.join(model_dir, 'config_after_eval.gin'))
          num_evals += 1

  if jax.host_id() == 0:
    tf_summary.flush()
    summary_writer.close()
    _write_gin_configs(os.path.join(model_dir, 'config_end.gin'))
  return model
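

# Example invocation of `run_experiment` (the paths and batch size below are
# illustrative only; in the original setup these values are supplied via gin
# bindings and command-line flags):
#
#   model = run_experiment(
#       model_dir='/tmp/protein_lm_run',
#       data_dir='/tmp/protein_data',
#       batch_size_per_device=64,
#       eval_frequency=500,
#       max_train_steps=100000)
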
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a
        # json file, let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = (
            parser.parse_args_into_dataclasses())

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists "
            "and is not empty. Use --overwrite_output_dir to overwrite it.")

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        level="NOTSET",
        datefmt="[%X]",
    )
    # Log on each process the small summary:
    logger = logging.getLogger(__name__)

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name,
                token=training_args.hub_token)
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training
    # and evaluation files (see below) or just provide the name of one of the
    # public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the
    # first column if no column called 'text' is found. You can easily tweak
    # this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only
    # one local process can concurrently download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        datasets = load_dataset(data_args.dataset_name,
                                data_args.dataset_config_name,
                                cache_dir=model_args.cache_dir)

        if "validation" not in datasets.keys():
            datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = data_args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        datasets = load_dataset(extension,
                                data_files=data_files,
                                cache_dir=model_args.cache_dir)

        if "validation" not in datasets.keys():
            datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    # See more about loading any type of standard or custom dataset (from
    # files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer

    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can
    # concurrently download model & vocab.
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name,
                                            cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path,
                                            cache_dir=model_args.cache_dir)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning(
            "You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not "
            "supported by this script. You can do it from another script, "
            "save it, and load it from here, using --tokenizer_name.")

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = datasets["train"].column_names
    else:
        column_names = datasets["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    if data_args.line_by_line:
        # When using line_by_line, we just tokenize each nonempty line.
padding = "max_length" if data_args.pad_to_max_length else False def tokenize_function(examples): # Remove empty lines examples = [ line for line in examples if len(line) > 0 and not line.isspace() ] return tokenizer( examples, return_special_tokens_mask=True, padding=padding, truncation=True, max_length=max_seq_length, ) tokenized_datasets = datasets.map( tokenize_function, input_columns=[text_column_name], batched=True, num_proc=data_args.preprocessing_num_workers, remove_columns=column_names, load_from_cache_file=not data_args.overwrite_cache, ) else: # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts. # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more # efficient when it receives the `special_tokens_mask`. def tokenize_function(examples): return tokenizer(examples[text_column_name], return_special_tokens_mask=True) tokenized_datasets = datasets.map( tokenize_function, batched=True, num_proc=data_args.preprocessing_num_workers, remove_columns=column_names, load_from_cache_file=not data_args.overwrite_cache, ) # Main data processing function that will concatenate all texts from our dataset and generate chunks of # max_seq_length. def group_texts(examples): # Concatenate all texts. concatenated_examples = { k: list(chain(*examples[k])) for k in examples.keys() } total_length = len(concatenated_examples[list(examples.keys())[0]]) # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can # customize this part to your needs. if total_length >= max_seq_length: total_length = (total_length // max_seq_length) * max_seq_length # Split by chunks of max_len. result = { k: [ t[i:i + max_seq_length] for i in range(0, total_length, max_seq_length) ] for k, t in concatenated_examples.items() } return result # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value # might be slower to preprocess. # # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map tokenized_datasets = tokenized_datasets.map( group_texts, batched=True, num_proc=data_args.preprocessing_num_workers, load_from_cache_file=not data_args.overwrite_cache, ) # Enable tensorboard only on the master node has_tensorboard = is_tensorboard_available() if has_tensorboard and jax.process_index() == 0: try: from flax.metrics.tensorboard import SummaryWriter summary_writer = SummaryWriter( log_dir=Path(training_args.output_dir)) except ImportError as ie: has_tensorboard = False logger.warning( f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" ) else: logger.warning( "Unable to display metrics through TensorBoard because the package is not installed: " "Please run pip install tensorboard to enable.") # Data collator # This one will take care of randomly masking the tokens. 
    data_collator = FlaxDataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    if model_args.model_name_or_path:
        model = FlaxAutoModelForMaskedLM.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype))
    else:
        model = FlaxAutoModelForMaskedLM.from_config(
            config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype))

    # Store some constants
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(
        training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(
        training_args.per_device_eval_batch_size) * jax.device_count()

    num_train_steps = len(
        tokenized_datasets["train"]) // train_batch_size * num_epochs

    # Create learning rate schedule
    warmup_fn = optax.linear_schedule(
        init_value=0.0,
        end_value=training_args.learning_rate,
        transition_steps=training_args.warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=training_args.learning_rate,
        end_value=0,
        transition_steps=num_train_steps - training_args.warmup_steps,
    )
    linear_decay_lr_schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn],
        boundaries=[training_args.warmup_steps])

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    # Note that this mask is specifically adapted for FlaxBERT-like models.
    # For other models, one should correct the layer norm parameter naming
    # accordingly.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        flat_mask = {
            path: (path[-1] != "bias"
                   and path[-2:] != ("LayerNorm", "scale"))
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)

    # Create the optimizer (Adafactor if requested, otherwise AdamW).
    if training_args.adafactor:
        # We use the default parameters here to initialize adafactor. For
        # more details about the parameters, please check
        # https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
        optimizer = optax.adafactor(
            learning_rate=linear_decay_lr_schedule_fn)
    else:
        optimizer = optax.adamw(
            learning_rate=linear_decay_lr_schedule_fn,
            b1=training_args.adam_beta1,
            b2=training_args.adam_beta2,
            eps=training_args.adam_epsilon,
            weight_decay=training_args.weight_decay,
            mask=decay_mask_fn,
        )

    # Setup train state
    state = train_state.TrainState.create(
        apply_fn=model.__call__, params=model.params, tx=optimizer)

    # Define gradient update step fn
    def train_step(state, batch, dropout_rng):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

        def loss_fn(params):
            labels = batch.pop("labels")

            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)[0]

            # Compute loss, ignoring padded input tokens.
            label_mask = jnp.where(labels > 0, 1.0, 0.0)
            loss = optax.softmax_cross_entropy(
                logits, onehot(labels, logits.shape[-1])) * label_mask

            # Take the average over non-padded tokens.
            loss = loss.sum() / label_mask.sum()

            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)

        metrics = jax.lax.pmean(
            {
                "loss": loss,
                "learning_rate": linear_decay_lr_schedule_fn(state.step)
            },
            axis_name="batch")

        return new_state, metrics, new_dropout_rng

    # Create parallel version of the train step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, ))

    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")

        logits = model(**batch, params=params, train=False)[0]

        # Compute loss, ignoring padded input tokens.
        label_mask = jnp.where(labels > 0, 1.0, 0.0)
        loss = optax.softmax_cross_entropy(
            logits, onehot(labels, logits.shape[-1])) * label_mask

        # Compute accuracy.
        accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask

        # Summarize metrics.
        metrics = {
            "loss": loss.sum(),
            "accuracy": accuracy.sum(),
            "normalizer": label_mask.sum()
        }
        metrics = jax.lax.psum(metrics, axis_name="batch")

        return metrics

    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0, ))

    # Replicate the train state on each device
    state = jax_utils.replicate(state)

    train_time = 0
    epochs = tqdm(range(num_epochs),
                  desc=f"Epoch ... (1/{num_epochs})",
                  position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()
        train_metrics = []

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train
        # dataset.
        num_train_samples = len(tokenized_datasets["train"])
        train_samples_idx = jax.random.permutation(
            input_rng, jnp.arange(num_train_samples))
        train_batch_idx = generate_batch_splits(train_samples_idx,
                                                train_batch_size)

        # Gather the indexes for creating the batch and do a training step
        for step, batch_idx in enumerate(
                tqdm(train_batch_idx, desc="Training...", position=1)):
            samples = [
                tokenized_datasets["train"][int(idx)] for idx in batch_idx
            ]
            model_inputs = data_collator(samples, pad_to_multiple_of=16)

            # Model forward
            model_inputs = shard(model_inputs.data)
            state, train_metric, dropout_rngs = p_train_step(
                state, model_inputs, dropout_rngs)
            train_metrics.append(train_metric)

            cur_step = epoch * (num_train_samples // train_batch_size) + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = jax_utils.unreplicate(train_metric)
                train_time += time.time() - train_start
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics,
                                       train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss']}, "
                    f"Learning Rate: {train_metric['learning_rate']})")

                train_metrics = []

            if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                # ==================== Evaluating ==========================
                num_eval_samples = len(tokenized_datasets["validation"])
                eval_samples_idx = jnp.arange(num_eval_samples)
                eval_batch_idx = generate_batch_splits(eval_samples_idx,
                                                       eval_batch_size)

                eval_metrics = []
                for i, batch_idx in enumerate(
                        tqdm(eval_batch_idx, desc="Evaluating ...",
                             position=2)):
                    samples = [
                        tokenized_datasets["validation"][int(idx)]
                        for idx in batch_idx
                    ]
                    model_inputs = data_collator(samples,
                                                 pad_to_multiple_of=16)

                    # Model forward
                    model_inputs = shard(model_inputs.data)
                    metrics = p_eval_step(state.params, model_inputs)
                    eval_metrics.append(metrics)

                # Normalize eval metrics.
                eval_metrics = get_metrics(eval_metrics)
                eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
                eval_normalizer = eval_metrics.pop("normalizer")
                eval_metrics = jax.tree_map(lambda x: x / eval_normalizer,
                                            eval_metrics)

                # Update progress bar
                epochs.desc = (
                    f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, "
                    f"Acc: {eval_metrics['accuracy']})")

                # Save metrics
                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if cur_step % training_args.save_steps == 0 and cur_step > 0:
                # Save checkpoint and push it to the hub.
                if jax.process_index() == 0:
                    params = jax.device_get(
                        jax.tree_map(lambda x: x[0], state.params))
                    model.save_pretrained(training_args.output_dir,
                                          params=params)
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(
                            commit_message=
                            f"Saving weights and logs of step {cur_step}",
                            blocking=False)

    # Eval after training
    if training_args.do_eval:
        num_eval_samples = len(tokenized_datasets["validation"])
        eval_samples_idx = jnp.arange(num_eval_samples)
        eval_batch_idx = generate_batch_splits(eval_samples_idx,
                                               eval_batch_size)

        eval_metrics = []
        for _, batch_idx in enumerate(
                tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
            samples = [
                tokenized_datasets["validation"][int(idx)]
                for idx in batch_idx
            ]
            model_inputs = data_collator(samples, pad_to_multiple_of=16)

            # Model forward
            model_inputs = shard(model_inputs.data)
            metrics = p_eval_step(state.params, model_inputs)
            eval_metrics.append(metrics)

        # Normalize eval metrics.
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(),
                                    eval_metrics)
        eval_normalizer = eval_metrics.pop("normalizer")
        eval_metrics = jax.tree_map(lambda x: x / eval_normalizer,
                                    eval_metrics)

        try:
            perplexity = math.exp(eval_metrics["loss"])
        except OverflowError:
            perplexity = float("inf")
        eval_metrics["perplexity"] = perplexity

        if jax.process_index() == 0:
            eval_metrics = {
                f"eval_{metric_name}": value
                for metric_name, value in eval_metrics.items()
            }
            path = os.path.join(training_args.output_dir, "eval_results.json")
            with open(path, "w") as f:
                json.dump(eval_metrics, f, indent=4, sort_keys=True)
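

# `generate_batch_splits` and `write_eval_metric` are helpers defined
# elsewhere in this script. Plausible minimal versions are sketched below;
# the drop-remainder behavior is an assumption inferred from how the batch
# indices are consumed above, not the verified original.
def generate_batch_splits(samples_idx, batch_size):
    num_samples = len(samples_idx)
    samples_to_remove = num_samples % batch_size
    if samples_to_remove != 0:
        # Drop the ragged tail so every batch holds exactly `batch_size` samples.
        samples_idx = samples_idx[:-samples_to_remove]
    sections_split = num_samples // batch_size
    return np.split(samples_idx, sections_split)


def write_eval_metric(summary_writer, eval_metrics, step):
    for metric_name, value in eval_metrics.items():
        summary_writer.scalar(f"eval_{metric_name}", value, step)
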
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    configure_logger(model_args, training_args)

    # Downloading and loading a dataset from the hub.
    datasets = load_dataset(data_args.dataset_name,
                            data_args.dataset_config_name,
                            cache_dir=model_args.cache_dir)

    if "validation" not in datasets.keys():
        # Make sure only "validation" and "train" keys remain.
        datasets = DatasetDict()
        datasets["validation"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=
            f"{data_args.train_split_name}[:{data_args.validation_split_percentage}%]",
            cache_dir=model_args.cache_dir,
        )
        datasets["train"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=
            f"{data_args.train_split_name}[{data_args.validation_split_percentage}%:]",
            cache_dir=model_args.cache_dir,
        )
    else:
        # Make sure only "validation" and "train" keys remain.
        datasets = DatasetDict()
        datasets["validation"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split="validation",
            cache_dir=model_args.cache_dir,
        )
        datasets["train"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=f"{data_args.train_split_name}",
            cache_dir=model_args.cache_dir,
        )

    # Only normalized-inputs-training is supported.
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        do_normalize=True)

    def prepare_dataset(batch):
        # Check that all files have the correct sampling rate.
        batch["speech"], _ = librosa.load(
            batch[data_args.speech_file_column],
            sr=feature_extractor.sampling_rate)
        return batch

    # Load audio files into numpy arrays.
    vectorized_datasets = datasets.map(
        prepare_dataset,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=datasets["train"].column_names)

    # Filter audio files that are too long.
    vectorized_datasets = vectorized_datasets.filter(
        lambda data: len(data["speech"]) < int(
            data_args.max_duration_in_seconds *
            feature_extractor.sampling_rate))

    def normalize(batch):
        return feature_extractor(
            batch["speech"],
            sampling_rate=feature_extractor.sampling_rate)

    # Normalize and transform to `BatchFeatures`.
    vectorized_datasets = vectorized_datasets.map(
        normalize,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
        remove_columns=vectorized_datasets["train"].column_names,
    )

    # Pretraining is only supported for the "newer" stable layer norm
    # architecture: apply_spec_augment has to be True and mask_feature_prob
    # has to be 0.0.
    config = Wav2Vec2Config.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        gradient_checkpointing=model_args.gradient_checkpointing,
    )

    if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
        raise ValueError(
            "PreTraining is only supported for "
            "``config.do_stable_layer_norm=True`` and "
            "``config.feat_extract_norm='layer'``.")

    model = FlaxWav2Vec2ForPreTraining(
        config,
        seed=training_args.seed,
        dtype=getattr(jnp, model_args.dtype))

    data_collator = FlaxDataCollatorForWav2Vec2Pretraining(
        model=model,
        feature_extractor=feature_extractor,
        pad_to_multiple_of=data_args.pad_to_multiple_of)

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter
            summary_writer = SummaryWriter(
                log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                "Unable to display metrics through TensorBoard because some "
                f"packages are not installed: {ie}")
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the "
            "package is not installed. Please run `pip install tensorboard` "
            "to enable it.")

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())
    gumbel_rngs = jax.random.split(rng, jax.local_device_count())

    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(
        training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(
        training_args.per_device_eval_batch_size) * jax.device_count()

    num_train_steps = len(
        vectorized_datasets["train"]) // train_batch_size * num_epochs

    # Create learning rate schedule
    warmup_fn = optax.linear_schedule(
        init_value=0.0,
        end_value=training_args.learning_rate,
        transition_steps=training_args.warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=training_args.learning_rate,
        end_value=0,
        transition_steps=num_train_steps - training_args.warmup_steps,
    )
    linear_decay_lr_schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn],
        boundaries=[training_args.warmup_steps])

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        flat_mask = {
            path: (path[-1] != "bias"
                   and path[-2:] not in [("layer_norm", "scale"),
                                         ("final_layer_norm", "scale")])
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)

    # Create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )

    # Setup train state and define training hyper-parameters
    state = train_state.TrainState.create(
        apply_fn=model.__call__, params=model.params, tx=adamw)
    num_negatives = model.config.num_negatives
    contrastive_logits_temperature = model.config.contrastive_logits_temperature
    num_codevectors = (model.config.num_codevectors_per_group *
                       model.config.num_codevector_groups)
    diversity_loss_weight = model.config.diversity_loss_weight

    # Define gradient update step fn
    def train_step(state, batch, dropout_rng, gumbel_rng):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        gumbel_rng, new_gumbel_rng = jax.random.split(gumbel_rng)

        def loss_fn(params):
            negative_indices = batch.pop("sampled_negative_indices")

            gumbel_temperature = jnp.clip(
                model_args.max_gumbel_temperature *
                model_args.gumbel_temperature_decay**state.step,
                a_min=model_args.min_gumbel_temperature,
            )

            outputs = state.apply_fn(
                **batch,
                gumbel_temperature=gumbel_temperature,
                params=params,
                dropout_rng=dropout_rng,
                gumbel_rng=gumbel_rng,
                train=True,
            )

            contrastive_loss = compute_contrastive_loss(
                outputs.projected_quantized_states,
                outputs.projected_states,
                negative_indices,
                batch["mask_time_indices"],
                contrastive_logits_temperature,
                num_negatives,
            )

            diversity_loss = (num_codevectors -
                              outputs.codevector_perplexity) / num_codevectors
            loss = contrastive_loss + diversity_loss_weight * diversity_loss

            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)

        metrics = jax.lax.pmean(
            {
                "loss": loss,
                "learning_rate": linear_decay_lr_schedule_fn(state.step)
            },
            axis_name="batch")

        return new_state, metrics, new_dropout_rng, new_gumbel_rng

    # Create parallel version of the train step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, ))

    # Define eval fn
    def eval_step(params, batch):
        negative_indices = batch.pop("sampled_negative_indices")

        outputs = model(**batch, params=params, train=False)

        contrastive_loss = compute_contrastive_loss(
            outputs.projected_quantized_states,
            outputs.projected_states,
            negative_indices,
            batch["mask_time_indices"],
            contrastive_logits_temperature,
            num_negatives,
        )

        diversity_loss = (num_codevectors -
                          outputs.codevector_perplexity) / num_codevectors
        loss = contrastive_loss + diversity_loss_weight * diversity_loss

        # Summarize metrics.
        metrics = {
            "loss": loss.mean(),
            "codevector_perplexity": outputs.codevector_perplexity
        }
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return metrics

    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0, ))

    # Replicate the train state on each device
    state = jax_utils.replicate(state)

    train_time = 0
    train_metrics = []
    epochs = tqdm(range(num_epochs),
                  desc=f"Epoch ... (1/{num_epochs})",
                  position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train
        # dataset.
        num_train_samples = len(vectorized_datasets["train"])
        train_samples_idx = jax.random.permutation(
            input_rng, jnp.arange(num_train_samples))
        train_batch_idx = generate_batch_splits(train_samples_idx,
                                                train_batch_size)

        # Gather the indexes for creating the batch and do a training step
        for step, batch_idx in enumerate(
                tqdm(train_batch_idx, desc="Training...", position=1)):
            samples = [
                vectorized_datasets["train"][int(idx)] for idx in batch_idx
            ]
            model_inputs = data_collator(samples)
            model_inputs = shard(model_inputs.data)

            # Model forward
            state, train_metric, dropout_rngs, gumbel_rngs = p_train_step(
                state, model_inputs, dropout_rngs, gumbel_rngs)
            train_metrics.append(train_metric)

            cur_step = epoch * (num_train_samples // train_batch_size) + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = jax_utils.unreplicate(train_metric)
                train_time += time.time() - train_start
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics,
                                       train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | "
                    f"Loss: {train_metric['loss'].mean()}, "
                    f"Learning Rate: {train_metric['learning_rate'].mean()})")

                train_metrics = []

        # ======================== Evaluating ==============================
        num_eval_samples = len(vectorized_datasets["validation"])
        eval_samples_idx = jnp.arange(num_eval_samples)
        eval_batch_idx = generate_batch_splits(eval_samples_idx,
                                               eval_batch_size)

        eval_metrics = []
        for i, batch_idx in enumerate(
                tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
            samples = [
                vectorized_datasets["validation"][int(idx)]
                for idx in batch_idx
            ]
            model_inputs = data_collator(samples)

            # Model forward
            model_inputs = shard(model_inputs.data)
            metrics = p_eval_step(state.params, model_inputs)
            eval_metrics.append(metrics)

        # Get eval metrics.
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

        # Update progress bar
        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | "
            f"Loss: {eval_metrics['loss']}, "
            f"Perplexity: {eval_metrics['codevector_perplexity']})")

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            cur_step = epoch * (len(vectorized_datasets["train"]) //
                                train_batch_size)
            write_eval_metric(summary_writer, eval_metrics, cur_step)

        # Save checkpoint after each epoch and push checkpoint to the hub.
        if jax.process_index() == 0:
            params = jax.device_get(
                jax.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(training_args.output_dir,
                                  params=params,
                                  push_to_hub=training_args.push_to_hub)
def main(_):
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  if not gfile.isdir(FLAGS.save_dir):
    gfile.mkdir(FLAGS.save_dir)

  hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
  # Get hyperparameters.
  if FLAGS.xm_parameters:
    for key, value in json.loads(FLAGS.xm_parameters).items():
      if key not in hparam_str_dict:
        hparam_str_dict[key] = value

  hparam_str = ','.join(['%s=%s' % (shorten(k), str(hparam_str_dict[k]))
                         for k in sorted(hparam_str_dict.keys())])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  program_shape = (FLAGS.per_device_batch_size,
                   FLAGS.num_partial_programs,
                   FLAGS.max_program_length)
  split_io_shape = (FLAGS.per_device_batch_size,
                    FLAGS.num_strings_per_task,
                    FLAGS.num_partial_programs,
                    FLAGS.max_characters)

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables.
  id_char_table = {i+1: char for (i, char) in enumerate(dsl.CHARACTER)}
  char_id_table = {char: id for id, char in id_char_table.items()}
  id_token_table, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1

  bos_token = token_id_table[dsl.BOS]
  eos_token = token_id_table[dsl.EOS]

  # Parse io and program token sequences (for eval).
  def decode_io(inputs, outputs):
    """Decode io examples tokens."""
    def decode_str(s):
      """Decode string tokens."""
      return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

    inps, outs = [], []
    for inp, out in zip(inputs, outputs):
      inps.append(decode_str(inp))
      outs.append(decode_str(out))
    return inps, outs

  def decode_program(program):
    """Decode program tokens."""
    # Concatenate all partial programs.
    full_program = []
    for p in program:
      full_program.extend(p[:np.argmax(p == eos_token)].astype(np.int32))
    full_program = np.concatenate([full_program, [eos_token]], axis=0)

    try:
      return dsl.decode_program(full_program, id_token_table)
    except:  # pylint: disable=bare-except
      return None  # Program does not compile.

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.dataset_filepattern, token_id_table, char_id_table,
      num_partial_programs=FLAGS.num_partial_programs)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:],
                     split_io_shape[1:]),
      drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease batch of predict dataset to handle beam search.
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:],
                     split_io_shape[1:]))
  train_ds = dataset.skip(FLAGS.num_eval_steps).repeat().prefetch(5)
  train_iter = train_ds.as_numpy_iterator()

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = base_models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(
      shift=False, deterministic=True, decode=not FLAGS.slow_decode)

  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  m = models.DecomposeExpandingLayerTransformer(
      config=eval_config,
      num_partial_programs=FLAGS.num_partial_programs,
      use_expanding_layer=FLAGS.use_expanding_layer)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))

  adam_opt_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = adam_opt_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)
    if start_step > 0:
      start_step += 1

  # Build Pretraining Model and Optimizer (if specified)
  # ---------------------------------------------------------------------------
  pretrain_optimizer = None  # Optimizer used for pretraining.
  split_target = None  # Split pretrained model on partial programs.
  if start_step < FLAGS.num_pretrain_steps:
    # Load in pretraining optimizer.
    def filter_fn(path, value):
      del value
      if FLAGS.freeze_encoder and path.startswith('/encoder'):
        return False
      if FLAGS.freeze_decoder and path.startswith('/decoder'):
        return False
      return True
    trainable_weights = optim.ModelParamTraversal(filter_fn)
    pretrain_opt_def = optim.MultiOptimizer((trainable_weights, adam_opt_def))
    pretrain_optimizer = pretrain_opt_def.create(optimizer.target)

    if FLAGS.pretrain_checkpoint_format:
      pretrain_exprs = FLAGS.max_expressions // FLAGS.num_partial_programs
      checkpoint_dir = FLAGS.pretrain_checkpoint_format.format(pretrain_exprs)

      if gfile.isdir(checkpoint_dir):
        # Use the pretrained parameters if no training has occurred yet.
        if start_step == 0:
          restore_paths = []
          if FLAGS.restore_encoder:
            restore_paths.append('target/encoder')
          if FLAGS.restore_decoder:
            restore_paths.append('target/decoder')

          pretrain_optimizer = restore_selected_paths(
              pretrain_optimizer,
              checkpoint_dir=checkpoint_dir,
              restore_paths=restore_paths)
          logging.info('Found model pretrained at %s.', checkpoint_dir)

        if FLAGS.match_split_encoding:
          split_model = models.DecomposeExpandingLayerTransformer(
              config=eval_config,
              num_partial_programs=1,
              use_expanding_layer=False)
          split_program_shape = (FLAGS.per_device_batch_size,
                                 1,
                                 FLAGS.max_program_length)
          split_initial_variables = jax.jit(split_model.init)(
              init_rng,
              jnp.ones(io_shape, jnp.float32),
              jnp.ones(io_shape, jnp.float32),
              jnp.ones(split_program_shape, jnp.float32))
          split_optimizer = adam_opt_def.create(
              split_initial_variables['params'])
          split_optimizer = checkpoints.restore_checkpoint(
              checkpoint_dir, split_optimizer)
          split_target = split_optimizer.target
      else:
        logging.warning('Could not find model at %s.', checkpoint_dir)

    if FLAGS.match_split_encoding and (split_target is None):
      raise RuntimeError('We could not load the pretrained checkpoint, '
                         'which is needed to match split embeddings.')

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=FLAGS.lr)
  p_pretrain_step = jax.pmap(
      functools.partial(
          pretrain_step,
          num_partial_programs=FLAGS.num_partial_programs,
          learning_rate_fn=learning_rate_fn,
          config=train_config,
          use_expanding_layer=FLAGS.use_expanding_layer,
          split_params=split_target),
      axis_name='batch')
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          num_partial_programs=FLAGS.num_partial_programs,
          learning_rate_fn=learning_rate_fn,
          config=train_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          num_partial_programs=FLAGS.num_partial_programs,
          eos_token=eos_token,
          config=eval_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          eos_token=eos_token,
          config=predict_config,
          slow_decode=FLAGS.slow_decode,
          use_expanding_layer=FLAGS.use_expanding_layer),
      axis_name='batch',
      static_broadcasted_argnums=(4,))
  p_split_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          num_partial_programs=FLAGS.num_partial_programs,
          max_decode_len=FLAGS.max_program_length,
          eos_token=eos_token,
          config=predict_config,
          slow_decode=FLAGS.slow_decode,
          use_expanding_layer=False,
          use_split_encoding=True,
          split_params=split_target),
      axis_name='batch',
      static_broadcasted_argnums=(4,))

  # Main Train Loop
  # ---------------------------------------------------------------------------
  train_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  # Replicate optimizer.
  if pretrain_optimizer:
    pretrain_optimizer = jax_utils.replicate(pretrain_optimizer)
  optimizer = jax_utils.replicate(optimizer)

  metrics_all = []
  tick = time.time()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, programs, split_outputs = (
        common_utils.shard(next(train_iter)))

    if step < FLAGS.num_pretrain_steps:
      pretrain_optimizer, metrics, train_rngs = p_pretrain_step(
          pretrain_optimizer, inputs, outputs, programs,
          split_outputs=split_outputs,
          pretrain_rng=train_rngs)
    else:
      optimizer, metrics, train_rngs = p_train_step(
          optimizer, inputs, outputs, programs,
          train_rng=train_rngs)
    metrics_all.append(metrics)
    is_last_pretrain_step = step == FLAGS.num_pretrain_steps - 1
    is_last_step = step == FLAGS.num_train_steps - 1

    if is_last_pretrain_step:
      optimizer = maybe_copy_model_from_pretraining(
          optimizer, pretrain_optimizer, step, adam_opt_def)

    # Save a Checkpoint
    if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step:
      optimizer = maybe_copy_model_from_pretraining(
          optimizer, pretrain_optimizer, step, adam_opt_def)
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer),
            step)

    # Periodic metric handling.
    if (not step or (step % FLAGS.log_freq != 0 and not is_last_step
                     and not is_last_pretrain_step)):
      continue

    optimizer = maybe_copy_model_from_pretraining(
        optimizer, pretrain_optimizer, step, adam_opt_def)

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(
        lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
        metrics_sums)
    summary['learning_rate'] = lr
    # Calculate (clipped) perplexity after averaging log-perplexities:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

    if jax.host_id() == 0:
      logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
      tock = time.time()
      steps_per_sec = FLAGS.log_freq / (tock - tick)
      tick = tock
      summary_writer.scalar('train/steps per second', steps_per_sec, step)
      for key, val in summary.items():
        summary_writer.scalar('train/' + key, val, step)
      summary_writer.flush()
    # Reset metric accumulation for next evaluation cycle.
    metrics_all = []

    # Evaluation Metrics
    logging.info('Gathering evaluation metrics.')
    t_evaluation_start = time.time()
    eval_summary = evaluate(
        p_eval_step=p_eval_step,
        target=optimizer.target,
        eval_ds=eval_ds)
    if jax.host_id() == 0:
      logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                   time.time()-t_evaluation_start, step,
                   eval_summary['loss'])
      for key, val in eval_summary.items():
        summary_writer.scalar('eval/' + key, val, step)
      summary_writer.flush()

    # Beam search metrics.
    logging.info('Gathering beam search metrics.')
    for beam_size in [1, 10, 12, 24, 48, 96]:
      t_inference_start = time.time()
      pred_acc, message = predict_and_compute_score(
          p_pred_step=p_pred_step,
          p_init_cache=p_init_cache,
          target=optimizer.target,
          predict_ds=predict_ds,
          decode_io=decode_io,
          decode_program=decode_program,
          beam_size=beam_size,
          num_partial_programs=FLAGS.num_partial_programs,
          use_best_first_search=FLAGS.best_first_search,
          slow_decode=FLAGS.slow_decode)

      # Write to tensorboard.
if jax.host_id() == 0: slow_or_fast = 'slow' if FLAGS.slow_decode else 'fast' logging.info( 'Prediction time, %s (beam %d): %.4f s, step %d, score %.4f', slow_or_fast, beam_size, time.time() - t_inference_start, step, pred_acc) beam_search_or_bfs = 'bfs' if FLAGS.best_first_search else 'beam-search' summary_writer.scalar( 'predict-{}/score-{}-{}'.format(slow_or_fast, beam_search_or_bfs, beam_size), pred_acc, step) summary_writer.text('samples-{}'.format(beam_size), '\n------\n'.join(message), step) summary_writer.flush() if step < FLAGS.num_pretrain_steps and FLAGS.match_split_encoding: pred_acc, message = predict_and_compute_score( p_pred_step=p_split_pred_step, p_init_cache=p_init_cache, target=optimizer.target, predict_ds=predict_ds, decode_io=decode_io, decode_program=decode_program, beam_size=beam_size, num_partial_programs=FLAGS.num_partial_programs, use_best_first_search=FLAGS.best_first_search, slow_decode=FLAGS.slow_decode) # Write to tensorboard. if jax.host_id() == 0: slow_or_fast = 'slow' if FLAGS.slow_decode else 'fast' beam_search_or_bfs = ('bfs' if FLAGS.best_first_search else 'beam-search') summary_writer.scalar( 'predict-split-{}/score-{}-{}'.format(slow_or_fast, beam_search_or_bfs, beam_size), pred_acc, step) summary_writer.text('samples-split-{}'.format(beam_size), '\n------\n'.join(message), step) summary_writer.flush()
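# ---------------------------------------------------------------------------
# Aside: a minimal, self-contained sketch (illustrative, not part of the script
# above) of the jax.pmap + static_broadcasted_argnums pattern used for
# p_pred_step and p_split_pred_step: the beam size is a plain Python int that
# changes traced output shapes, so it is broadcast as a static argument rather
# than sharded like the array inputs.
import jax
import jax.numpy as jnp

def toy_predict(params, batch, beam_size):
  # beam_size must be static because it determines the output shape.
  return jnp.repeat(batch * params, beam_size, axis=-1)

p_toy_predict = jax.pmap(
    toy_predict, axis_name='batch', static_broadcasted_argnums=(2,))

n = jax.local_device_count()
out = p_toy_predict(jnp.ones((n, 1)), jnp.ones((n, 4)), 3)
assert out.shape == (n, 4 * 3)  # beam_size=3 replicated on every device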
except ImportError as ie: has_tensorboard = False logger.warning( f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}" ) summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) # Data collator # This one will take care of randomly masking the tokens. data_collator = FlaxDataCollatorForLanguageModeling( tokenizer=tokenizer, mlm_probability=data_args.mlm_probability) # Initialize our training rng = jax.random.PRNGKey(training_args.seed) dropout_rngs = jax.random.split(rng, jax.local_device_count()) if model_args.model_name_or_path: model = FlaxAutoModelForMaskedLM.from_pretrained( model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)) else: model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=getattr( jnp, model_args.dtype)) # Store some constant
def main(args): logdir = os.path.join(args.logdir, args.name) logger = logging.setup_logger(logdir) logger.info(args) logger.info(f'Available devices: {jax.devices()}') # Setup input pipeline dataset_info = input_pipeline.get_dataset_info(args.dataset, 'train') ds_train = input_pipeline.get_data( dataset=args.dataset, mode='train', repeats=None, mixup_alpha=args.mixup_alpha, batch_size=args.batch, shuffle_buffer=args.shuffle_buffer, tfds_data_dir=args.tfds_data_dir, tfds_manual_dir=args.tfds_manual_dir) batch = next(iter(ds_train)) logger.info(ds_train) ds_test = input_pipeline.get_data( dataset=args.dataset, mode='test', repeats=1, batch_size=args.batch_eval, tfds_data_dir=args.tfds_data_dir, tfds_manual_dir=args.tfds_manual_dir) logger.info(ds_test) # Build VisionTransformer architecture model = models.KNOWN_MODELS[args.model] VisionTransformer = model.partial(num_classes=dataset_info['num_classes']) _, params = VisionTransformer.init_by_shape( jax.random.PRNGKey(0), # Discard the "num_local_devices" dimension for initialization. [(batch['image'].shape[1:], batch['image'].dtype.name)]) pretrained_path = os.path.join(args.vit_pretrained_dir, f'{args.model}.npz') params = checkpoint.load_pretrained( pretrained_path=pretrained_path, init_params=params, model_config=models.CONFIGS[args.model], logger=logger) # pmap replicates the model over all TPUs/GPUs vit_fn_repl = jax.pmap(VisionTransformer.call) update_fn_repl = make_update_fn(VisionTransformer.call, args.accum_steps) # Create optimizer and replicate it over all TPUs/GPUs opt = momentum_clip.Optimizer( dtype=args.optim_dtype, grad_norm_clip=args.grad_norm_clip).create(params) opt_repl = flax_utils.replicate(opt) # Delete references to the objects that are not needed anymore del opt del params def copyfiles(paths): """Small helper to copy files to args.copy_to using tf.io.gfile.""" if not args.copy_to: return for path in paths: to_path = os.path.join(args.copy_to, args.name, os.path.basename(path)) tf.io.gfile.makedirs(os.path.dirname(to_path)) tf.io.gfile.copy(path, to_path, overwrite=True) logger.info(f'Copied {path} to {to_path}.') total_steps = args.total_steps or ( input_pipeline.DATASET_PRESETS[args.dataset]['total_steps']) # Prepare the learning-rate and pre-fetch it to device to avoid delays. 
lr_fn = hyper.create_learning_rate_schedule(total_steps, args.base_lr, args.decay_type, args.warmup_steps) lr_iter = hyper.lr_prefetch_iter(lr_fn, 0, total_steps) update_rngs = jax.random.split( jax.random.PRNGKey(0), jax.local_device_count()) # Run training loop writer = metric_writers.create_default_writer(logdir, asynchronous=False) writer.write_hparams({k: v for k, v in vars(args).items() if v is not None}) logger.info('Starting training loop; initial compile can take a while...') t0 = time.time() for step, batch, lr_repl in zip( range(1, total_steps + 1), input_pipeline.prefetch(ds_train, args.prefetch), lr_iter): opt_repl, loss_repl, update_rngs = update_fn_repl( opt_repl, lr_repl, batch, update_rngs) if step == 1: logger.info(f'First step took {time.time() - t0:.1f} seconds.') t0 = time.time() if args.progress_every and step % args.progress_every == 0: writer.write_scalars(step, dict(train_loss=float(loss_repl[0]))) done = step / total_steps logger.info(f'Step: {step}/{total_steps} {100*done:.1f}%, ' f'ETA: {(time.time()-t0)/done*(1-done)/3600:.2f}h') copyfiles(glob.glob(f'{logdir}/*')) # Run eval step if ((args.eval_every and step % args.eval_every == 0) or (step == total_steps)): accuracy_test = np.mean([ c for batch in input_pipeline.prefetch(ds_test, args.prefetch) for c in ( np.argmax(vit_fn_repl(opt_repl.target, batch['image']), axis=2) == np.argmax(batch['label'], axis=2)).ravel() ]) lr = float(lr_repl[0]) logger.info(f'Step: {step} ' f'Learning rate: {lr:.7f}, ' f'Test accuracy: {accuracy_test:0.5f}') writer.write_scalars(step, dict(accuracy_test=accuracy_test, lr=lr)) copyfiles(glob.glob(f'{logdir}/*')) if args.output: checkpoint.save(flax_utils.unreplicate(opt_repl.target), args.output) logger.info(f'Stored fine tuned checkpoint to {args.output}') copyfiles([args.output])
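# ---------------------------------------------------------------------------
# Aside: a hedged sketch of what a warmup-plus-cosine schedule like the one
# returned by `hyper.create_learning_rate_schedule(total_steps, args.base_lr,
# args.decay_type, args.warmup_steps)` above might look like; the library's
# actual decay types may differ, and `make_lr_schedule` here is an assumed,
# simplified stand-in.
import numpy as np

def make_lr_schedule(total_steps, base_lr, warmup_steps=0):
  def lr_fn(step):
    # Linear warmup from 0 to base_lr over `warmup_steps`.
    warmup = np.minimum(1., step / max(warmup_steps, 1))
    # Cosine decay from base_lr to 0 over the remaining steps.
    progress = np.clip(
        (step - warmup_steps) / max(total_steps - warmup_steps, 1), 0., 1.)
    return base_lr * warmup * 0.5 * (1. + np.cos(np.pi * progress))
  return lr_fn

lr_fn = make_lr_schedule(total_steps=100, base_lr=0.03, warmup_steps=10)
assert abs(lr_fn(10) - 0.03) < 1e-12  # warmup peaks at base_lr
assert lr_fn(100) < 1e-12             # fully decayed by the last step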
def get_data(*, dataset, mode, repeats, batch_size, mixup_alpha=0, tfds_manual_dir=None, inception_crop=True): """Returns dataset for training/eval. Args: dataset: Dataset name. In addition to the requirement that this dataset must be in tensorflow_datasets, the dataset must be registered in `DATASET_PRESETS` (specifying crop size etc). mode: Must be "train" or "test". repeats: How many times the dataset should be repeated. For indefinite repeats specify None. batch_size: Global batch size. Note that the returned dataset will have dimensions [local_devices, batch_size / local_devices, ...]. mixup_alpha: Coefficient for mixup combination. See https://arxiv.org/abs/1710.09412 tfds_manual_dir: Optional directory that contains downloaded files for tensorflow_dataset preparation. inception_crop: If set to True, tf.image.sample_distorted_bounding_box() will be used. If set to False, tf.image.random_crop() will be used. """ preset = DATASET_PRESETS.get(dataset) if preset is None: raise KeyError( f'Please add "{dataset}" to {__name__}.DATASET_PRESETS') split = preset[mode] resize_size = preset['resize'] crop_size = preset['crop'] dataset_info = get_dataset_info(dataset, split) data_builder = tfds.builder(dataset) data_builder.download_and_prepare( download_config=tfds.download.DownloadConfig( manual_dir=tfds_manual_dir)) data = data_builder.as_dataset( split=split, decoders={'image': tfds.decode.SkipDecoding()}) decoder = data_builder.info.features['image'].decode_example def _pp(data): im = decoder(data['image']) if mode == 'train': if inception_crop: channels = im.shape[-1] begin, size, _ = tf.image.sample_distorted_bounding_box( tf.shape(im), tf.zeros([0, 0, 4], tf.float32), area_range=(0.05, 1.0), min_object_covered=0, # Don't enforce a minimum area. use_image_if_no_bounding_boxes=True) im = tf.slice(im, begin, size) # Unfortunately, the above operation loses the depth-dimension. So we # need to restore it the manual way. im.set_shape([None, None, channels]) im = tf.image.resize(im, [crop_size, crop_size]) else: im = tf.image.resize(im, [resize_size, resize_size]) im = tf.image.random_crop(im, [crop_size, crop_size, 3]) im = tf.image.random_flip_left_right(im) else: # usage of crop_size here is intentional im = tf.image.resize(im, [crop_size, crop_size]) im = (im - 127.5) / 127.5 label = tf.one_hot(data['label'], dataset_info['num_classes']) # pylint: disable=no-value-for-parameter return {'image': im, 'label': label} data = data.repeat(repeats) if mode == 'train': data = data.shuffle(min(dataset_info['num_examples'], MAX_IN_MEMORY)) data = data.map(_pp, tf.data.experimental.AUTOTUNE) data = data.batch(batch_size, drop_remainder=True) def _mixup(data): beta_dist = tfp.distributions.Beta(mixup_alpha, mixup_alpha) beta = tf.cast(beta_dist.sample([]), tf.float32) data['image'] = (beta * data['image'] + (1 - beta) * tf.reverse(data['image'], axis=[0])) data['label'] = (beta * data['label'] + (1 - beta) * tf.reverse(data['label'], axis=[0])) return data if mixup_alpha is not None and mixup_alpha > 0.0 and mode == 'train': data = data.map(_mixup, tf.data.experimental.AUTOTUNE) # Shard data such that it can be distributed across devices num_devices = jax.local_device_count() def _shard(data): data['image'] = tf.reshape(data['image'], [num_devices, -1, crop_size, crop_size, 3]) data['label'] = tf.reshape( data['label'], [num_devices, -1, dataset_info['num_classes']]) return data if num_devices is not None: data = data.map(_shard, tf.data.experimental.AUTOTUNE) return data.prefetch(1)
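# ---------------------------------------------------------------------------
# Aside: a tiny worked example (illustrative data, not from the pipeline) of
# the mixup step in `_mixup` above: a single Beta-sampled coefficient is shared
# by the whole batch, and each example is mixed with its mirror image under
# tf.reverse along the batch axis.
import tensorflow as tf

images = tf.constant([[0.], [1.]])           # a "batch" of two 1-pixel images
labels = tf.constant([[1., 0.], [0., 1.]])   # one-hot labels
beta = 0.7                                   # stands in for Beta(alpha, alpha).sample()
mixed_images = beta * images + (1 - beta) * tf.reverse(images, axis=[0])
mixed_labels = beta * labels + (1 - beta) * tf.reverse(labels, axis=[0])
# mixed_images == [[0.3], [0.7]]; mixed_labels == [[0.7, 0.3], [0.3, 0.7]]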
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str): """Runs a training and evaluation loop. Args: config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains a checkpoint, training will be resumed from the latest checkpoint. """ tf.io.gfile.makedirs(workdir) vocab_path = config.vocab_path if vocab_path is None: vocab_path = os.path.join(workdir, "sentencepiece_model") config.vocab_path = vocab_path tf.io.gfile.makedirs(os.path.split(vocab_path)[0]) # Load Dataset # --------------------------------------------------------------------------- logging.info("Initializing dataset.") train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets( n_devices=jax.local_device_count(), config=config, vocab_path=vocab_path) train_iter = iter(train_ds) vocab_size = int(encoder.vocab_size()) eos_id = decode.EOS_ID # Default Sentencepiece EOS token. def decode_tokens(toks): valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32) return encoder.detokenize(valid_toks).numpy().decode("utf-8") if config.num_predict_steps > 0: predict_ds = predict_ds.take(config.num_predict_steps) logging.info("Initializing model, optimizer, and step functions.") # Build Model and Optimizer # --------------------------------------------------------------------------- train_config = models.TransformerConfig( vocab_size=vocab_size, output_vocab_size=vocab_size, share_embeddings=config.share_embeddings, logits_via_embedding=config.logits_via_embedding, dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32, emb_dim=config.emb_dim, num_heads=config.num_heads, num_layers=config.num_layers, qkv_dim=config.qkv_dim, mlp_dim=config.mlp_dim, max_len=max(config.max_target_length, config.max_eval_target_length), dropout_rate=config.dropout_rate, attention_dropout_rate=config.attention_dropout_rate, deterministic=False, decode=False, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6)) eval_config = train_config.replace(deterministic=True) predict_config = train_config.replace(deterministic=True, decode=True) start_step = 0 rng = jax.random.PRNGKey(config.seed) rng, init_rng = jax.random.split(rng) input_shape = (config.per_device_batch_size, config.max_target_length) target_shape = (config.per_device_batch_size, config.max_target_length) m = models.Transformer(eval_config) initial_variables = jax.jit(m.init)(init_rng, jnp.ones(input_shape, jnp.float32), jnp.ones(target_shape, jnp.float32)) # apply an optimizer to this tree optimizer_def = optim.Adam(config.learning_rate, beta1=0.9, beta2=0.98, eps=1e-9, weight_decay=config.weight_decay) optimizer = optimizer_def.create(initial_variables["params"]) # We access model params only from optimizer below via optimizer.target. del initial_variables if config.restore_checkpoints: # Restore unreplicated optimizer + model state from last checkpoint. optimizer = checkpoints.restore_checkpoint(workdir, optimizer) # Grab last step. start_step = int(optimizer.state.step) writer = metric_writers.create_default_writer( workdir, just_logging=jax.host_id() > 0) if start_step == 0: writer.write_hparams(dict(config)) # Replicate optimizer. optimizer = jax_utils.replicate(optimizer) learning_rate_fn = create_learning_rate_scheduler( base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps) # compile multidevice versions of train/eval/predict step and cache init fn. 
p_train_step = jax.pmap(functools.partial( train_step, config=train_config, learning_rate_fn=learning_rate_fn, label_smoothing=config.label_smoothing), axis_name="batch", donate_argnums=(0, )) # pytype: disable=wrong-arg-types p_eval_step = jax.pmap(functools.partial( eval_step, config=eval_config, label_smoothing=config.label_smoothing), axis_name="batch") p_init_cache = jax.pmap(functools.partial( initialize_cache, max_decode_len=config.max_predict_length, config=predict_config), axis_name="batch") p_pred_step = jax.pmap( functools.partial(predict_step, config=predict_config, beam_size=config.beam_size), axis_name="batch", static_broadcasted_argnums=(3, 4)) # eos token, max_length are constant # Main Train Loop # --------------------------------------------------------------------------- # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap'd training update for performance. dropout_rngs = jax.random.split(rng, jax.local_device_count()) del rng logging.info("Starting training loop.") hooks = [] report_progress = periodic_actions.ReportProgress( num_train_steps=config.num_train_steps, writer=writer) if jax.host_id() == 0: hooks += [ report_progress, periodic_actions.Profile(logdir=workdir, num_profile_steps=5) ] train_metrics = [] with metric_writers.ensure_flushes(writer): for step in range(start_step, config.num_train_steps): is_last_step = step == config.num_train_steps - 1 # Shard data to devices and do a training step. with jax.profiler.StepTraceContext("train", step_num=step): batch = common_utils.shard( jax.tree_map(np.asarray, next(train_iter))) optimizer, metrics = p_train_step(optimizer, batch, dropout_rng=dropout_rngs) train_metrics.append(metrics) # Quick indication that training is happening. logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step) for h in hooks: h(step) # Periodic metric handling. if step % config.eval_every_steps == 0 or is_last_step: with report_progress.timed("training_metrics"): logging.info("Gathering training metrics.") train_metrics = common_utils.get_metrics(train_metrics) lr = train_metrics.pop("learning_rate").mean() metrics_sums = jax.tree_map(jnp.sum, train_metrics) denominator = metrics_sums.pop("denominator") summary = jax.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop summary["learning_rate"] = lr summary = {"train_" + k: v for k, v in summary.items()} writer.write_scalars(step, summary) train_metrics = [] with report_progress.timed("eval"): eval_results = evaluate( p_eval_step=p_eval_step, target=optimizer.target, eval_ds=eval_ds, num_eval_steps=config.num_eval_steps) writer.write_scalars( step, {"eval_" + k: v for k, v in eval_results.items()}) with report_progress.timed("translate_and_bleu"): exemplars, bleu_score = translate_and_calculate_bleu( p_pred_step=p_pred_step, p_init_cache=p_init_cache, target=optimizer.target, predict_ds=predict_ds, decode_tokens=decode_tokens, max_predict_length=config.max_predict_length) writer.write_scalars(step, {"bleu": bleu_score}) writer.write_texts(step, {"samples": exemplars}) # Save a checkpoint on one host after every checkpoint_freq steps. save_checkpoint = (step % config.checkpoint_every_steps == 0 or is_last_step) if config.save_checkpoints and save_checkpoint and jax.host_id( ) == 0: with report_progress.timed("checkpoint"): checkpoints.save_checkpoint( workdir, jax_utils.unreplicate(optimizer), step)
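# ---------------------------------------------------------------------------
# Aside: a minimal sketch of what `common_utils.shard` (used in the training
# loop above) does to each array in the batch, based on its documented
# behavior: the leading batch axis is split across local devices so pmap can
# map over it. Written here for illustration only.
import jax
import numpy as np

def shard_like_common_utils(tree):
  n = jax.local_device_count()
  return jax.tree_map(
      lambda x: x.reshape((n, x.shape[0] // n) + x.shape[1:]), tree)

batch = {"inputs": np.zeros((8, 16)), "targets": np.zeros((8, 16))}
sharded = shard_like_common_utils(batch)
# With 1 local device each leaf becomes (1, 8, 16); with 8 devices, (8, 1, 16).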
def compute_preconditioners_from_statistics(self, states, hps, step): """Compute preconditioners for statistics.""" statistics = [] num_statistics_per_state = [] original_shapes = [] exponents = [] max_size = 0 prev_preconditioners = [] for state in states: num_statistics = len(state.statistics) num_statistics_per_state.append(num_statistics) original_shapes_for_state = [] if num_statistics > 0: for statistic in state.statistics: exponents.append(2 * num_statistics if hps.exponent_override == 0 else hps.exponent_override) original_shapes_for_state.append(statistic.shape) max_size = max(max_size, statistic.shape[0]) statistics.extend(state.statistics) prev_preconditioners.extend(state.preconditioners) original_shapes.extend(original_shapes_for_state) num_statistics = len(statistics) def pack(mat, max_size): """Pack a matrix to a max_size for inverse on TPUs with static shapes. Args: mat: Matrix for computing inverse pth root. max_size: Matrix size to pack to. Returns: Given M returns [[M, 0], [0, I]] """ size = mat.shape[0] assert size <= max_size if size == max_size: return mat pad_size = max_size - size zs1 = jnp.zeros([size, pad_size], dtype=mat.dtype) zs2 = jnp.zeros([pad_size, size], dtype=mat.dtype) eye = jnp.eye(pad_size, dtype=mat.dtype) mat = jnp.concatenate([mat, zs1], 1) mat = jnp.concatenate([mat, jnp.concatenate([zs2, eye], 1)], 0) return mat if not hps.batch_axis_name: num_devices = jax.local_device_count() else: num_devices = lax.psum(1, hps.batch_axis_name) # Pad statistics and exponents to next multiple of num_devices. packed_statistics = [pack(stat, max_size) for stat in statistics] to_pad = -num_statistics % num_devices packed_statistics.extend([ jnp.eye(max_size, dtype=packed_statistics[0].dtype) for _ in range(to_pad) ]) exponents.extend([1 for _ in range(to_pad)]) # Batch statistics and exponents so that the leading axis is # num_devices. def _batch(statistics, exponents, num_devices): assert len(statistics) == len(exponents) n = len(statistics) b = int(n / num_devices) batched_statistics = [ jnp.stack(statistics[idx:idx + b]) for idx in range(0, n, b) ] batched_exponents = [ jnp.stack(exponents[idx:idx + b]) for idx in range(0, n, b) ] return jnp.stack(batched_statistics), jnp.stack(batched_exponents) # Unbatch values across leading axis and return a list of elements. 
def _unbatch(batched_values): b1, b2 = batched_values.shape[0], batched_values.shape[1] results = [] for v_array in jnp.split(batched_values, b1, 0): for v in jnp.split(jnp.squeeze(v_array), b2, 0): results.append(jnp.squeeze(v)) return results all_statistics, all_exponents = _batch(packed_statistics, exponents, num_devices) def _matrix_inverse_pth_root(xs, ps): mi_pth_root = lambda x, y: matrix_inverse_pth_root( # pylint: disable=g-long-lambda x, y, ridge_epsilon=hps.matrix_eps) preconditioners, errors = jax.vmap(mi_pth_root)(xs, ps) return preconditioners, errors if not hps.batch_axis_name: preconditioners, errors = jax.pmap(_matrix_inverse_pth_root)( all_statistics, all_exponents) preconditioners_flat = _unbatch(preconditioners) errors_flat = _unbatch(errors) else: def _internal_inverse_pth_root_all(): preconditioners = jnp.array(all_statistics) current_replica = lax.axis_index(hps.batch_axis_name) preconditioners, errors = _matrix_inverse_pth_root( all_statistics[current_replica], all_exponents[current_replica]) preconditioners = jax.lax.all_gather(preconditioners, hps.batch_axis_name) errors = jax.lax.all_gather(errors, hps.batch_axis_name) preconditioners_flat = _unbatch(preconditioners) errors_flat = _unbatch(errors) return preconditioners_flat, errors_flat if hps.preconditioning_compute_steps == 1: preconditioners_flat, errors_flat = _internal_inverse_pth_root_all() else: # Passing statistics instead of preconditioners, as they are similarly # shaped tensors; since the error we pass is the failure threshold, these # will be ignored. preconditioners_init = packed_statistics errors_init = ([_INVERSE_PTH_ROOT_FAILURE_THRESHOLD] * len(packed_statistics)) init_state = [preconditioners_init, errors_init] perform_step = step % hps.preconditioning_compute_steps == 0 preconditioners_flat, errors_flat = self.fast_cond( perform_step, _internal_inverse_pth_root_all, init_state) def _skip(error): return jnp.logical_or( jnp.isnan(error), error >= _INVERSE_PTH_ROOT_FAILURE_THRESHOLD).astype(error.dtype) def _select_preconditioner(error, new_p, old_p): return lax.cond( _skip(error), lambda _: old_p, lambda _: new_p, operand=None) new_preconditioners_flat = [] for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes, prev_preconditioners, errors_flat): new_preconditioners_flat.append( _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p)) assert len(states) == len(num_statistics_per_state) assert len(new_preconditioners_flat) == num_statistics # Add back empty preconditioners so that we can set the optimizer state. preconditioners_for_states = [] idx = 0 for num_statistics, state in zip(num_statistics_per_state, states): if num_statistics == 0: preconditioners_for_states.append([]) else: preconditioners_for_state = new_preconditioners_flat[idx:idx + num_statistics] assert len(state.statistics) == len(preconditioners_for_state) preconditioners_for_states.append(preconditioners_for_state) idx += num_statistics new_states = [] for state, new_preconditioners in zip(states, preconditioners_for_states): new_states.append( _ShampooDefaultParamState(state.diagonal_statistics, state.statistics, new_preconditioners, state.diagonal_momentum, state.momentum)) return new_states
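# ---------------------------------------------------------------------------
# Aside: a standalone illustration of the `pack` helper above. A statistic M is
# embedded into a fixed-size block matrix [[M, 0], [0, I]]; the inverse pth
# root of that block matrix keeps the same block structure, so the padding can
# simply be sliced away afterwards (the p[:shape[0], :shape[1]] step).
import jax.numpy as jnp

m = 2. * jnp.eye(2)  # a 2x2 statistic, packed to max_size 3
packed = jnp.block([
    [m, jnp.zeros((2, 1))],
    [jnp.zeros((1, 2)), jnp.eye(1)],
])
# packed == [[2, 0, 0], [0, 2, 0], [0, 0, 1]]; its inverse square root is
# diag(1/sqrt(2), 1/sqrt(2), 1), and the top-left 2x2 block is exactly the
# preconditioner that would have been computed for M alone.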
def update_step( state, transitions, in_initial_bc_iters, ): def reshape_for_devices(t): rest_t_shape = list(t.shape[1:]) new_shape = [ num_devices, t.shape[0] // num_devices, ] + rest_t_shape return jnp.reshape(t, new_shape) transitions = jax.tree_map(reshape_for_devices, transitions) key, key_alpha, key_critic, key_actor = jax.random.split( state.key, 4) if adaptive_entropy_coefficient: alpha = jnp.exp(state.alpha_params) else: alpha = entropy_coefficient key_critic = jax.random.split(key_critic, jax.local_device_count()) # print(jax.tree_map(lambda t: t.shape, state.q_params)) total_critic_loss_and_aux, q_params, new_target_q_params, q_optimizer_state = pmapped_critic_update( state.q_params, state.target_q_params, state.q_optimizer_state, state.policy_params, alpha, cql_alpha, transitions, key_critic, ) # print(jax.tree_map(lambda t: t.shape, q_params)) total_critic_loss_and_aux = jax.tree_map( jnp.mean, total_critic_loss_and_aux) key_actor = jax.random.split(key_actor, jax.local_device_count()) # if in_initial_bc_iters: # pmapped_actor_update = pmapped_actor_update_in_bc_iters # else: # pmapped_actor_update = pmapped_actor_update_after_bc_iters policy_params, policy_optimizer_state, actor_loss, min_q, avg_log_prob, sn, new_snr_state = pmapped_actor_update( in_initial_bc_iters, state.policy_params, state.policy_optimizer_state, state.q_params, state.target_q_params, alpha, transitions, state.snr_state, key_actor, ) avg_log_prob = jnp.mean(avg_log_prob) critic_loss_aux = total_critic_loss_and_aux[1] # metrics = { # 'critic_loss': critic_loss_aux['critic_loss'], # 'cql_loss': critic_loss_aux['cql_loss'], # 'actor_loss': actor_loss, # } metrics = OrderedDict() metrics['actor_loss'] = jnp.mean(actor_loss) metrics['avg_log_prob'] = avg_log_prob metrics['total_critic_loss'] = total_critic_loss_and_aux[0] metrics['critic_loss'] = critic_loss_aux['critic_loss'] metrics['cql_loss'] = critic_loss_aux['cql_loss'] metrics['q/avg'] = jnp.mean(min_q) metrics['q/std'] = jnp.std(min_q) metrics['q/max'] = jnp.max(min_q) metrics['q/min'] = jnp.min(min_q) metrics['SNR/loss'] = jnp.mean(sn) new_state = TrainingState( policy_optimizer_state=policy_optimizer_state, q_optimizer_state=q_optimizer_state, policy_params=policy_params, q_params=q_params, target_q_params=new_target_q_params, key=key, snr_state=new_snr_state, ) if adaptive_entropy_coefficient and (not in_initial_bc_iters): # Apply alpha gradients alpha_loss, alpha_grads = alpha_grad(state.alpha_params, avg_log_prob) alpha_update, alpha_optimizer_state = alpha_optimizer.update( alpha_grads, state.alpha_optimizer_state) alpha_params = optax.apply_updates(state.alpha_params, alpha_update) metrics['alpha_loss'] = alpha_loss metrics['alpha'] = jnp.exp(alpha_params) new_state = new_state._replace( alpha_optimizer_state=alpha_optimizer_state, alpha_params=alpha_params) else: metrics['alpha_loss'] = 0. 
metrics['alpha'] = jnp.exp(state.alpha_params) new_state = new_state._replace( alpha_optimizer_state=state.alpha_optimizer_state, alpha_params=state.alpha_params) # metrics['observations_mean'] = jnp.mean( # utils.batch_concat( # jax.tree_map(lambda x: jnp.abs(jnp.mean(x, axis=0)), # transitions.observation))) # metrics['observations_std'] = jnp.mean( # utils.batch_concat( # jax.tree_map(lambda x: jnp.std(x, axis=0), # transitions.observation))) # metrics['next_observations_mean'] = jnp.mean( # utils.batch_concat( # jax.tree_map(lambda x: jnp.abs(jnp.mean(x, axis=0)), # transitions.next_observation))) # metrics['next_observations_std'] = jnp.mean( # utils.batch_concat( # jax.tree_map(lambda x: jnp.std(x, axis=0), # transitions.next_observation))) return new_state, metrics
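# ---------------------------------------------------------------------------
# Aside: a self-contained sketch (with illustrative names and an assumed
# target_entropy) of the adaptive entropy-coefficient update pattern used
# above: alpha is kept in log space, and its loss moves the policy entropy
# toward a target value, mirroring the optax.apply_updates call in the code.
import jax
import jax.numpy as jnp
import optax

target_entropy = -1.0  # assumption: -action_dim is a common default

def alpha_loss_fn(log_alpha, avg_log_prob):
  # Gradients increase alpha when policy entropy falls below the target.
  return jnp.exp(log_alpha) * jax.lax.stop_gradient(
      -avg_log_prob - target_entropy)

alpha_optimizer = optax.adam(3e-4)
log_alpha = jnp.asarray(0.)
alpha_opt_state = alpha_optimizer.init(log_alpha)

loss, grads = jax.value_and_grad(alpha_loss_fn)(log_alpha, avg_log_prob=-2.0)
updates, alpha_opt_state = alpha_optimizer.update(grads, alpha_opt_state)
log_alpha = optax.apply_updates(log_alpha, updates)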
def main(unused_argv): rng = random.PRNGKey(20200823) # Shift the numpy random seed by host_id() to shuffle data loaded by different # hosts. np.random.seed(20201473 + jax.host_id()) if FLAGS.config is not None: utils.update_flags(FLAGS) if FLAGS.batch_size % jax.device_count() != 0: raise ValueError( "Batch size must be divisible by the number of devices.") if FLAGS.train_dir is None: raise ValueError("train_dir must be set; it is currently None.") if FLAGS.data_dir is None: raise ValueError("data_dir must be set; it is currently None.") dataset = datasets.get_dataset("train", FLAGS) test_dataset = datasets.get_dataset("test", FLAGS) test_render_fn = jax.pmap( # Note rng_keys are useless in eval mode since there's no randomness. # pylint: disable=g-long-lambda lambda key_0, key_1, model, rays: jax.lax.all_gather( model(key_0, key_1, *rays), axis_name="batch"), in_axes=(None, None, None, 0), # Only distribute the data input. donate_argnums=3, axis_name="batch", ) rng, key = random.split(rng) init_model, init_state = models.get_model(key, dataset.peek(), FLAGS) optimizer_def = optim.Adam(FLAGS.lr_init) optimizer = optimizer_def.create(init_model) state = model_utils.TrainState(step=0, optimizer=optimizer, model_state=init_state) if not utils.isdir(FLAGS.train_dir): utils.makedirs(FLAGS.train_dir) state = checkpoints.restore_checkpoint(FLAGS.train_dir, state) offset = state.step + 1 state = jax_utils.replicate(state) del init_model, init_state if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir) t_loop_start = time.time() learning_rate_fn = functools.partial(utils.learning_rate_decay, lr_init=FLAGS.lr_init, lr_final=FLAGS.lr_final, max_steps=FLAGS.max_steps, lr_delay_steps=FLAGS.lr_delay_steps, lr_delay_mult=FLAGS.lr_delay_mult) ptrain_step = jax.pmap(train_step, axis_name="batch", in_axes=(0, 0, 0, None), donate_argnums=2) # Prefetch_buffer_size = 3 x batch_size pdataset = jax_utils.prefetch_to_device(dataset, 3) n_local_devices = jax.local_device_count() rng = rng + jax.host_id() # Make random seed separate across hosts. keys = random.split(rng, n_local_devices) # For pmapping RNG keys. gc.disable() # Disable automatic garbage collection for efficiency. stats_trace = [] for step, batch in zip(range(offset, FLAGS.max_steps + 1), pdataset): lr = learning_rate_fn(step) state, stats, keys = ptrain_step(keys, state, batch, lr) if jax.host_id() == 0: stats_trace.append(stats) if step % FLAGS.gc_every == 0: gc.collect() # --- Train logs start --- # Put the training time visualization before the host_id check as in # multi-host evaluation, all hosts need to run inference even though we # only use host 0 to record results. if FLAGS.render_every > 0 and step % FLAGS.render_every == 0: # We reuse the same random number generator from the optimization step # here on purpose so that the visualization matches what happened in # training. 
state_to_eval = jax.device_get(jax.tree_map(lambda x: x[0], state)) test_case = next(test_dataset) pred_color, pred_disp, pred_acc = utils.render_image( state_to_eval, test_case["rays"], test_render_fn, keys[0], FLAGS.dataset == "llff", chunk=FLAGS.chunk) if jax.host_id() == 0: psnr = utils.compute_psnr( ((pred_color - test_case["pixels"])**2).mean()) summary_writer.scalar("test_psnr", psnr, step) summary_writer.image("test_pred_color", pred_color, step) summary_writer.image("test_pred_disp", pred_disp, step) summary_writer.image("test_pred_acc", pred_acc, step) summary_writer.image("test_target", test_case["pixels"], step) if jax.host_id() != 0: # Only log via host 0. continue if step % FLAGS.print_every == 0: summary_writer.scalar("train_loss", stats.loss[0], step) summary_writer.scalar("train_psnr", stats.psnr[0], step) summary_writer.scalar("train_loss_coarse", stats.loss_c[0], step) summary_writer.scalar("train_psnr_coarse", stats.psnr_c[0], step) summary_writer.scalar("weight_l2", stats.weight_l2[0], step) avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace])) avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace])) stats_trace = [] summary_writer.scalar("train_avg_loss", avg_loss, step) summary_writer.scalar("train_avg_psnr", avg_psnr, step) summary_writer.scalar("learning_rate", lr, step) steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start) t_loop_start = time.time() rays_per_sec = FLAGS.batch_size * steps_per_sec summary_writer.scalar("steps_per_sec", steps_per_sec, step) summary_writer.scalar("rays_per_sec", rays_per_sec, step) precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1 print(("{:" + "{:d}".format(precision) + "d}").format(step) + f"/{FLAGS.max_steps:d}: " + f"i_loss={stats.loss[0]:0.5f}, " + f"avg_loss={avg_loss:0.5f}, " + f"weight_l2={stats.weight_l2[0]:0.2e}, " + f"lr={lr:0.2e}, " + f"{rays_per_sec:0.3f} rays/sec") if step % FLAGS.save_every == 0: state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state)) checkpoints.save_checkpoint(FLAGS.train_dir, state_to_save, state_to_save.step, keep=100) # --- Train logs end --- if FLAGS.max_steps % FLAGS.save_every != 0: state = jax.device_get(jax.tree_map(lambda x: x[0], state)) checkpoints.save_checkpoint(FLAGS.train_dir, state, int(state.step), keep=100)
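# ---------------------------------------------------------------------------
# Aside: a hedged sketch of the PSNR computation reached through
# `utils.compute_psnr(...)` above, assuming pixel values in [0, 1], where
# PSNR = -10 * log10(MSE); the helper name and exact convention are assumptions.
import jax.numpy as jnp

def compute_psnr(mse):
  return -10. * jnp.log10(mse)

assert jnp.allclose(compute_psnr(jnp.asarray(0.01)), 20.)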
def main(): # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. # We now keep distinct sets of args, for a cleaner separation of concerns. parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # If we pass only one argument to the script and it's the path to a json file, # let's parse it to get our arguments. model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) # Setup logging, we only want one process per machine to log things on the screen. logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) if jax.process_index() == 0: datasets.utils.logging.set_verbosity_warning() transformers.utils.logging.set_verbosity_info() else: datasets.utils.logging.set_verbosity_error() transformers.utils.logging.set_verbosity_error() # Handle the repository creation if training_args.push_to_hub: if training_args.hub_model_id is None: repo_name = get_full_repo_name( Path(training_args.output_dir).absolute().name, token=training_args.hub_token ) else: repo_name = training_args.hub_model_id repo = Repository(training_args.output_dir, clone_from=repo_name) # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # or just provide the name of one of the public datasets for the token classification task available on the hub at https://huggingface.co/datasets/ # (the dataset will be downloaded automatically from the datasets Hub). # # For CSV/JSON files, this script will use the column called 'tokens' or the first column if no column called # 'tokens' is found. You can easily tweak this behavior (see below). # # In distributed training, the load_dataset function guarantees that only one local process can concurrently # download the dataset. if data_args.dataset_name is not None: # Downloading and loading a dataset from the hub. raw_datasets = load_dataset( data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir ) else: # Loading the dataset from local csv or json file. data_files = {} if data_args.train_file is not None: data_files["train"] = data_args.train_file if data_args.validation_file is not None: data_files["validation"] = data_args.validation_file extension = (data_args.train_file if data_args.train_file is not None else data_args.validation_file).split(".")[-1] raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir) # See more about loading any type of standard or custom dataset at # https://huggingface.co/docs/datasets/loading_datasets.html. 
if raw_datasets["train"] is not None: column_names = raw_datasets["train"].column_names features = raw_datasets["train"].features else: column_names = raw_datasets["validation"].column_names features = raw_datasets["validation"].features if data_args.text_column_name is not None: text_column_name = data_args.text_column_name elif "tokens" in column_names: text_column_name = "tokens" else: text_column_name = column_names[0] if data_args.label_column_name is not None: label_column_name = data_args.label_column_name elif f"{data_args.task_name}_tags" in column_names: label_column_name = f"{data_args.task_name}_tags" else: label_column_name = column_names[1] # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the # unique labels. def get_label_list(labels): unique_labels = set() for label in labels: unique_labels = unique_labels | set(label) label_list = list(unique_labels) label_list.sort() return label_list if isinstance(features[label_column_name].feature, ClassLabel): label_list = features[label_column_name].feature.names # No need to convert the labels since they are already ints. label_to_id = {i: i for i in range(len(label_list))} else: label_list = get_label_list(raw_datasets["train"][label_column_name]) label_to_id = {l: i for i, l in enumerate(label_list)} num_labels = len(label_list) # Load pretrained model and tokenizer config = AutoConfig.from_pretrained( model_args.config_name if model_args.config_name else model_args.model_name_or_path, num_labels=num_labels, label2id=label_to_id, id2label={i: l for l, i in label_to_id.items()}, finetuning_task=data_args.task_name, cache_dir=model_args.cache_dir, revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, ) tokenizer_name_or_path = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path if config.model_type in {"gpt2", "roberta"}: tokenizer = AutoTokenizer.from_pretrained( tokenizer_name_or_path, cache_dir=model_args.cache_dir, revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, add_prefix_space=True, ) else: tokenizer = AutoTokenizer.from_pretrained( tokenizer_name_or_path, cache_dir=model_args.cache_dir, revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, ) model = FlaxAutoModelForTokenClassification.from_pretrained( model_args.model_name_or_path, config=config, cache_dir=model_args.cache_dir, revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, ) # Preprocessing the datasets # Tokenize all texts and align the labels with them. def tokenize_and_align_labels(examples): tokenized_inputs = tokenizer( examples[text_column_name], max_length=data_args.max_seq_length, padding="max_length", truncation=True, # We use this argument because the texts in our dataset are lists of words (with a label for each word). is_split_into_words=True, ) labels = [] for i, label in enumerate(examples[label_column_name]): word_ids = tokenized_inputs.word_ids(batch_index=i) previous_word_idx = None label_ids = [] for word_idx in word_ids: # Special tokens have a word id that is None. We set the label to -100 so they are automatically # ignored in the loss function. if word_idx is None: label_ids.append(-100) # We set the label for the first token of each word. 
elif word_idx != previous_word_idx: label_ids.append(label_to_id[label[word_idx]]) # For the other tokens in a word, we set the label to either the current label or -100, depending on # the label_all_tokens flag. else: label_ids.append(label_to_id[label[word_idx]] if data_args.label_all_tokens else -100) previous_word_idx = word_idx labels.append(label_ids) tokenized_inputs["labels"] = labels return tokenized_inputs processed_raw_datasets = raw_datasets.map( tokenize_and_align_labels, batched=True, num_proc=data_args.preprocessing_num_workers, load_from_cache_file=not data_args.overwrite_cache, remove_columns=raw_datasets["train"].column_names, desc="Running tokenizer on dataset", ) train_dataset = processed_raw_datasets["train"] eval_dataset = processed_raw_datasets["validation"] # Log a few random samples from the training set: for index in random.sample(range(len(train_dataset)), 3): logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") # Define a summary writer has_tensorboard = is_tensorboard_available() if has_tensorboard and jax.process_index() == 0: try: from flax.metrics.tensorboard import SummaryWriter summary_writer = SummaryWriter(training_args.output_dir) summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)}) except ImportError as ie: has_tensorboard = False logger.warning( f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}" ) else: logger.warning( "Unable to display metrics through TensorBoard because the package is not installed: " "Please run pip install tensorboard to enable." ) def write_train_metric(summary_writer, train_metrics, train_time, step): summary_writer.scalar("train_time", train_time, step) train_metrics = get_metrics(train_metrics) for key, vals in train_metrics.items(): tag = f"train_{key}" for i, val in enumerate(vals): summary_writer.scalar(tag, val, step - len(vals) + i + 1) def write_eval_metric(summary_writer, eval_metrics, step): for metric_name, value in eval_metrics.items(): summary_writer.scalar(f"eval_{metric_name}", value, step) num_epochs = int(training_args.num_train_epochs) rng = jax.random.PRNGKey(training_args.seed) dropout_rngs = jax.random.split(rng, jax.local_device_count()) train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count() eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count() learning_rate_fn = create_learning_rate_fn( len(train_dataset), train_batch_size, training_args.num_train_epochs, training_args.warmup_steps, training_args.learning_rate, ) state = create_train_state(model, learning_rate_fn, num_labels=num_labels, training_args=training_args) # define step functions def train_step( state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey ) -> Tuple[train_state.TrainState, float]: """Trains model with an optimizer (both in `state`) on `batch`, returning a triple `(new_state, metrics, new_dropout_rng)`.""" dropout_rng, new_dropout_rng = jax.random.split(dropout_rng) targets = batch.pop("labels") def loss_fn(params): logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] loss = state.loss_fn(logits, targets) return loss grad_fn = jax.value_and_grad(loss_fn) loss, grad = grad_fn(state.params) grad = jax.lax.pmean(grad, "batch") new_state = state.apply_gradients(grads=grad) metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch") return new_state, metrics, new_dropout_rng p_train_step = 
jax.pmap(train_step, axis_name="batch", donate_argnums=(0,)) def eval_step(state, batch): logits = state.apply_fn(**batch, params=state.params, train=False)[0] return state.logits_fn(logits) p_eval_step = jax.pmap(eval_step, axis_name="batch") metric = load_metric("seqeval") def get_labels(y_pred, y_true): # Transform predictions and references tensors to numpy arrays # Remove ignored index (special tokens) true_predictions = [ [label_list[p] for (p, l) in zip(pred, gold_label) if l != -100] for pred, gold_label in zip(y_pred, y_true) ] true_labels = [ [label_list[l] for (p, l) in zip(pred, gold_label) if l != -100] for pred, gold_label in zip(y_pred, y_true) ] return true_predictions, true_labels def compute_metrics(): results = metric.compute() if data_args.return_entity_level_metrics: # Unpack nested dictionaries final_results = {} for key, value in results.items(): if isinstance(value, dict): for n, v in value.items(): final_results[f"{key}_{n}"] = v else: final_results[key] = value return final_results else: return { "precision": results["overall_precision"], "recall": results["overall_recall"], "f1": results["overall_f1"], "accuracy": results["overall_accuracy"], } logger.info(f"===== Starting training ({num_epochs} epochs) =====") train_time = 0 # make sure weights are replicated on each device state = replicate(state) step_per_epoch = len(train_dataset) // train_batch_size total_steps = step_per_epoch * num_epochs epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) for epoch in epochs: train_start = time.time() train_metrics = [] # Create sampling rng rng, input_rng = jax.random.split(rng) # train for step, batch in enumerate( tqdm( train_data_collator(input_rng, train_dataset, train_batch_size), total=step_per_epoch, desc="Training...", position=1, ) ): state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs) train_metrics.append(train_metric) cur_step = (epoch * step_per_epoch) + (step + 1) if cur_step % training_args.logging_steps == 0 and cur_step > 0: # Save metrics train_metric = unreplicate(train_metric) train_time += time.time() - train_start if has_tensorboard and jax.process_index() == 0: write_train_metric(summary_writer, train_metrics, train_time, cur_step) epochs.write( f"Step... 
({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})" ) train_metrics = [] if cur_step % training_args.eval_steps == 0 and cur_step > 0: eval_metrics = {} # evaluate for batch in tqdm( eval_data_collator(eval_dataset, eval_batch_size), total=len(eval_dataset) // eval_batch_size, desc="Evaluating ...", position=2, ): labels = batch.pop("labels") predictions = p_eval_step(state, batch) predictions = np.array([pred for pred in chain(*predictions)]) labels = np.array([label for label in chain(*labels)]) labels[np.array(list(chain(*batch["attention_mask"]))) == 0] = -100 preds, refs = get_labels(predictions, labels) metric.add_batch( predictions=preds, references=refs, ) # evaluate also on leftover examples (not divisible by batch_size) num_leftover_samples = len(eval_dataset) % eval_batch_size # make sure leftover batch is evaluated on one device if num_leftover_samples > 0 and jax.process_index() == 0: # take leftover samples batch = eval_dataset[-num_leftover_samples:] batch = {k: np.array(v) for k, v in batch.items()} labels = batch.pop("labels") predictions = eval_step(unreplicate(state), batch) labels = np.array(labels) labels[np.array(batch["attention_mask"]) == 0] = -100 preds, refs = get_labels(predictions, labels) metric.add_batch( predictions=preds, references=refs, ) eval_metrics = compute_metrics() if data_args.return_entity_level_metrics: logger.info(f"Step... ({cur_step}/{total_steps} | Validation metrics: {eval_metrics})") else: logger.info( f"Step... ({cur_step}/{total_steps} | Validation f1: {eval_metrics['f1']}, Validation Acc: {eval_metrics['accuracy']})" ) if has_tensorboard and jax.process_index() == 0: write_eval_metric(summary_writer, eval_metrics, cur_step) if (cur_step % training_args.save_steps == 0 and cur_step > 0) or (cur_step == total_steps): # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: params = jax.device_get(unreplicate(state.params)) model.save_pretrained(training_args.output_dir, params=params) tokenizer.save_pretrained(training_args.output_dir) if training_args.push_to_hub: repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False) epochs.desc = f"Epoch ... 
{epoch + 1}/{num_epochs}" # Eval after training if training_args.do_eval: eval_metrics = {} eval_loader = eval_data_collator(eval_dataset, eval_batch_size) for batch in tqdm(eval_loader, total=len(eval_dataset) // eval_batch_size, desc="Evaluating ...", position=2): labels = batch.pop("labels") predictions = p_eval_step(state, batch) predictions = np.array([pred for pred in chain(*predictions)]) labels = np.array([label for label in chain(*labels)]) labels[np.array(list(chain(*batch["attention_mask"]))) == 0] = -100 preds, refs = get_labels(predictions, labels) metric.add_batch(predictions=preds, references=refs) # evaluate also on leftover examples (not divisible by batch_size) num_leftover_samples = len(eval_dataset) % eval_batch_size # make sure leftover batch is evaluated on one device if num_leftover_samples > 0 and jax.process_index() == 0: # take leftover samples batch = eval_dataset[-num_leftover_samples:] batch = {k: np.array(v) for k, v in batch.items()} labels = np.array(batch.pop("labels")) predictions = eval_step(unreplicate(state), batch) labels[np.array(batch["attention_mask"]) == 0] = -100 preds, refs = get_labels(predictions, labels) metric.add_batch(predictions=preds, references=refs) eval_metrics = compute_metrics() if jax.process_index() == 0: eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()} path = os.path.join(training_args.output_dir, "eval_results.json") with open(path, "w") as f: json.dump(eval_metrics, f, indent=4, sort_keys=True)
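# ---------------------------------------------------------------------------
# Aside: a tiny worked example (illustrative data) of the -100 masking
# convention used by `get_labels` above: positions labeled -100 (special
# tokens and, unless label_all_tokens is set, non-first sub-word tokens) are
# dropped before seqeval scoring.
label_list_demo = ["O", "B-ENT"]
y_pred = [[0, 1, 0]]
y_true = [[-100, 1, -100]]  # only the middle token is scored
true_predictions = [
    [label_list_demo[p] for (p, l) in zip(pred, gold) if l != -100]
    for pred, gold in zip(y_pred, y_true)
]
true_labels = [
    [label_list_demo[l] for (p, l) in zip(pred, gold) if l != -100]
    for pred, gold in zip(y_pred, y_true)
]
assert true_predictions == [["B-ENT"]] and true_labels == [["B-ENT"]]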
def main(argv): del argv # unused arg config = FLAGS.config # Unpack total and warmup steps # TODO(nband): revert this to separate arguments. total_steps = config.total_and_warmup_steps[0] warmup_steps = config.total_and_warmup_steps[1] del config.total_and_warmup_steps config.total_steps = total_steps config.lr.warmup_steps = warmup_steps # Wandb and Checkpointing Setup output_dir = FLAGS.output_dir wandb_run, output_dir = vit_utils.maybe_setup_wandb(config) tf.io.gfile.makedirs(output_dir) logging.info('Saving checkpoints at %s', output_dir) # Dataset Split Flags dist_shift = config.distribution_shift print(f'Distribution Shift: {dist_shift}.') dataset_names, split_names = vit_utils.get_dataset_and_split_names(dist_shift) # LR / Optimization Flags batch_size = config.batch_size grad_clip_norm = config.grad_clip_norm weight_decay = config.weight_decay print('Standard wandb hyperparameters:') print({ 'batch_size': batch_size, 'grad_clip_norm': grad_clip_norm, 'weight_decay': weight_decay, 'total_steps': config.total_steps, 'lr': config.lr }) print('SNGP Params:', config.gp_layer) # Reweighting loss for class imbalance # class_reweight_mode = config.class_reweight_mode # if class_reweight_mode == 'constant': # class_weights = utils.get_diabetic_retinopathy_class_balance_weights() # else: # class_weights = None # Shows the number of available devices. # In a CPU/GPU runtime this will be a single device. # In a TPU runtime this will be 8 cores. print('Number of Jax local devices:', jax.local_devices()) # TODO(nband): fix sigmoid loss issues. assert config.get('loss', None) == 'softmax_xent' seed = config.seed rng = jax.random.PRNGKey(seed) tf.random.set_seed(seed) if config.get('data_dir'): logging.info('data_dir=%s', config.data_dir) logging.info('Output dir: %s', output_dir) save_checkpoint_path = None if config.get('checkpoint_steps'): tf.io.gfile.makedirs(output_dir) save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz') # Create an asynchronous multi-metric writer. writer = metric_writers.create_default_writer( output_dir, just_logging=jax.process_index() > 0) # The pool is used to perform misc operations such as logging in async way. pool = multiprocessing.pool.ThreadPool() def write_note(note): if jax.process_index() == 0: logging.info('NOTE: %s', note) write_note('Initializing...') # Verify settings to make sure no checkpoints are accidentally missed. if config.get('keep_checkpoint_steps'): assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.' assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, ( f'`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should be ' f'divisible by `checkpoint_steps` ({config.checkpoint_steps}).') batch_size_eval = config.get('batch_size_eval', batch_size) if (batch_size % jax.device_count() != 0 or batch_size_eval % jax.device_count() != 0): raise ValueError(f'Batch sizes ({batch_size} and {batch_size_eval}) must ' f'be divisible by device number ({jax.device_count()})') local_batch_size = batch_size // jax.process_count() local_batch_size_eval = batch_size_eval // jax.process_count() logging.info( 'Global batch size %d on %d hosts results in %d local batch size. 
' 'With %d dev per host (%d dev total), that is a %d per-device batch size.', batch_size, jax.process_count(), local_batch_size, jax.local_device_count(), jax.device_count(), local_batch_size // jax.local_device_count()) write_note('Initializing preprocessing function...') # Same preprocessing function for training and evaluation preproc_fn = preprocess_spec.parse( spec=config.pp_train, available_ops=preprocess_utils.all_ops()) write_note('Initializing train dataset...') rng, train_ds_rng = jax.random.split(rng) train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index()) train_base_dataset = ub.datasets.get( dataset_names['in_domain_dataset'], split=split_names['train_split'], data_dir=config.get('data_dir')) train_dataset_builder = train_base_dataset._dataset_builder # pylint: disable=protected-access train_ds = input_utils.get_data( dataset=train_dataset_builder, split=split_names['train_split'], rng=train_ds_rng, process_batch_size=local_batch_size, preprocess_fn=preproc_fn, shuffle_buffer_size=config.shuffle_buffer_size, prefetch_size=config.get('prefetch_to_host', 2), data_dir=config.get('data_dir')) logging.info('image_size = %s', train_ds.element_spec['image'].shape[1:]) # Start prefetching already. train_iter = input_utils.start_input_pipeline( train_ds, config.get('prefetch_to_device', 1)) write_note('Initializing val dataset(s)...') # Load in-domain and OOD validation and/or test datasets. # Please specify the desired shift (Country Shift or Severity Shift) # in the config. eval_iter_splits = vit_utils.init_evaluation_datasets( use_validation=config.use_validation, use_test=config.use_test, dataset_names=dataset_names, split_names=split_names, config=config, preproc_fn=preproc_fn, batch_size_eval=batch_size_eval, local_batch_size_eval=local_batch_size_eval) ntrain_img = input_utils.get_num_examples( train_dataset_builder, split=split_names['train_split'], process_batch_size=local_batch_size, data_dir=config.get('data_dir')) steps_per_epoch = ntrain_img / batch_size if config.get('num_epochs'): total_steps = int(config.num_epochs * steps_per_epoch) assert not config.get('total_steps'), 'Set either num_epochs or total_steps' else: total_steps = config.total_steps logging.info('Total train data points: %d', ntrain_img) logging.info( 'Running for %d steps, that means %f epochs and %d steps per epoch', total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch) write_note('Initializing model...') logging.info('config.model = %s', config.get('model')) # Specify Gaussian process layer configs. gp_config = config.get('gp_layer', {}) model_dict = vit_utils.initialize_model('sngp', config) model, use_gp_layer = model_dict['model'], model_dict['use_gp_layer'] # We want all parameters to be created in host RAM, not on any device, they'll # be sent there later as needed, otherwise we already encountered two # situations where we allocate them twice. @functools.partial(jax.jit, backend='cpu') def init(rng): image_size = tuple(train_ds.element_spec['image'].shape[2:]) logging.info('image_size = %s', image_size) dummy_input = jnp.zeros((local_batch_size,) + image_size, jnp.float32) variables = model.init(rng, dummy_input, train=False) # Split model parameters into trainable and untrainable collections. states, params = variables.pop('params') del variables # Set bias in the head to a low value, such that loss is small initially. params = flax.core.unfreeze(params) if use_gp_layer: # Modify the head parameter in the GP head. 
params['head']['output_layer']['bias'] = jnp.full_like( params['head']['output_layer']['bias'], config.get('init_head_bias', 0)) else: params['head']['bias'] = jnp.full_like( params['head']['bias'], config.get('init_head_bias', 0)) return params, states rng, rng_init = jax.random.split(rng) params_cpu, states_cpu = init(rng_init) if jax.process_index() == 0: num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0]) parameter_overview.log_parameter_overview(params_cpu) writer.write_scalars(step=0, scalars={'num_params': num_params}) @functools.partial(jax.pmap, axis_name='batch') def evaluation_fn(params, states, images, labels): variable_dict = {'params': flax.core.freeze(params), **states} logits, out = model.apply( variable_dict, images, train=False, mean_field_factor=gp_config.get('mean_field_factor', -1.)) losses = getattr(train_utils, config.get('loss', 'softmax_xent'))( logits=logits, labels=labels, reduction=False) loss = jax.lax.psum(losses, axis_name='batch') top1_idx = jnp.argmax(logits, axis=1) # Extracts the label at the highest logit index for each image. top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0] ncorrect = jax.lax.psum(top1_correct, axis_name='batch') n = batch_size_eval metric_args = jax.lax.all_gather([ logits, labels, out['pre_logits']], axis_name='batch') return ncorrect, loss, n, metric_args # Load the optimizer from flax. opt_name = config.get('optim_name') write_note(f'Initializing {opt_name} optimizer...') opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {})) # We jit this, such that the arrays that are created are created on the same # device as the input is, in this case the CPU. Else they'd be on device[0]. opt_cpu = jax.jit(opt_def.create)(params_cpu) weight_decay_rules = config.get('weight_decay', []) or [] rescale_value = config.lr.base if config.get('weight_decay_decouple') else 1. weight_decay_fn = train_utils.get_weight_decay_fn( weight_decay_rules=weight_decay_rules, rescale_value=rescale_value) @functools.partial(jax.pmap, axis_name='batch', donate_argnums=(0,)) def update_fn(opt, states, lr, reset_covmat, images, labels, rng): """Update step.""" measurements = {} # Get device-specific loss rng. rng, rng_model = jax.random.split(rng, 2) rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index('batch')) def loss_fn(params, states, images, labels): # Specify mutable collection to update untrainable GP parameters. variable_dict = {'params': flax.core.freeze(params), **states} model_results, updated_states = model.apply( variable_dict, images, train=True, rngs={'dropout': rng_model_local}, mutable=list(states.keys()), mean_field_factor=gp_config.get('mean_field_factor', -1.)) logits, _ = model_results loss = getattr(train_utils, config.get('loss', 'sigmoid_xent'))( logits=logits, labels=labels) return loss, updated_states # Performs an exact covariance update (i.e., resetting the precision matrix # at the beginning of each new epoch) if covmat_momentum is a null value. if use_gp_layer and gp_config.get('covmat_momentum', -1.) < 0: # Resets the precision matrix to Identity * ridge_penalty at the beginning # of a new epoch. This should be done before accumulating gradients. ridge_penalty = gp_config.get('ridge_penalty', 1.) prec_mat_old = states['laplace_covariance']['head']['covmat_layer'][ 'precision_matrix'] prec_mat_new = ( (1. 
- reset_covmat) * prec_mat_old + reset_covmat * jnp.eye(prec_mat_old.shape[0]) * ridge_penalty) states = flax.core.unfreeze(states) states['laplace_covariance']['head']['covmat_layer'][ 'precision_matrix'] = prec_mat_new states = flax.core.freeze(states) # Implementation considerations compared and summarized at # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en# (l, s), g = vit_utils.accumulate_gradient_with_states( jax.value_and_grad(loss_fn, has_aux=True), opt.target, states, images, labels, config.get('grad_accum_steps')) l, g = jax.lax.pmean((l, g), axis_name='batch') # Log the gradient norm only if we need to compute it anyway (clipping) # or if we don't use grad_accum_steps, as they interact badly. if config.get('grad_accum_steps', 1) == 1 or grad_clip_norm is not None: grads, _ = jax.tree_flatten(g) l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads])) measurements['l2_grads'] = l2_g # Optionally rescale the global gradient to a maximum norm. We found this # useful in some cases across optimizers, hence it's in the main loop. if grad_clip_norm is not None: g_factor = jnp.minimum(1.0, grad_clip_norm / l2_g) g = jax.tree_map(lambda p: g_factor * p, g) opt = opt.apply_gradient(g, learning_rate=lr) opt = opt.replace(target=weight_decay_fn(opt.target, lr)) params, _ = jax.tree_flatten(opt.target) measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params])) measurements['reset_covmat'] = reset_covmat return opt, s, l, rng, measurements # Set config checkpoint resume path, if provided in args. if config.resume_checkpoint_path is not None: config.resume = config.resume_checkpoint_path default_reinit_params = ('head/output_layer/kernel', 'head/output_layer/bias', 'head/kernel', 'head/bias') rng, train_loop_rngs = jax.random.split(rng) checkpoint_data = checkpoint_utils.maybe_load_checkpoint( train_loop_rngs=train_loop_rngs, save_checkpoint_path=save_checkpoint_path, init_optimizer=opt_cpu, init_params=params_cpu, init_fixed_model_states=states_cpu, default_reinit_params=default_reinit_params, config=config) train_loop_rngs = checkpoint_data.train_loop_rngs opt_cpu = checkpoint_data.optimizer states_cpu = checkpoint_data.fixed_model_states accumulated_train_time = checkpoint_data.accumulated_train_time write_note('Adapting the checkpoint model...') adapted_params = checkpoint_utils.adapt_upstream_architecture( init_params=params_cpu, loaded_params=opt_cpu.target) opt_cpu = opt_cpu.replace(target=adapted_params) write_note('Kicking off misc stuff...') first_step = int(opt_cpu.state.step) # Might be a DeviceArray type. if first_step == 0 and jax.process_index() == 0: writer.write_hparams(dict(config)) chrono = train_utils.Chrono(first_step, total_steps, batch_size, accumulated_train_time) # Note: switch to ProfileAllHosts() if you need to profile all hosts. # (Xprof data become much larger and take longer to load for analysis) profiler = periodic_actions.Profile( # Create a profile after every restart to analyze pre-emption-related # problems and ensure we get similar performance in every run. logdir=output_dir, first_profile=first_step + 10) # Prepare the learning-rate schedule and pre-fetch it to device to avoid # delays. lr_fn = train_utils.create_learning_rate_schedule(total_steps, **config.get('lr', {})) # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be # necessary for TPUs.
lr_iter = train_utils.prefetch_scalar( map(lr_fn, range(total_steps)), config.get('prefetch_to_device', 1)) # Prepare the precision matrix resetting schedule, and pre-fetch it to device. reset_covmat_fn = lambda step: float(step % steps_per_epoch == 0) reset_covmat_iter = train_utils.prefetch_scalar( map(reset_covmat_fn, range(first_step, total_steps)), nprefetch=config.get('prefetch_to_device', 1)) write_note(f'Replicating...\n{chrono.note}') opt_repl = flax.jax_utils.replicate(opt_cpu) states_repl = flax.jax_utils.replicate(states_cpu) checkpoint_writer = None # Note: we return the train loss, val loss, and fewshot best l2s for use in # reproducibility unit tests. # train_loss = -jnp.inf # val_loss = -jnp.inf # results = {'dummy': {(0, 1): -jnp.inf}} write_note(f'First step compilations...\n{chrono.note}') logging.info('first_step = %s', first_step) # Advance the iterators if we are restarting from an earlier checkpoint. # TODO(dusenberrymw): Look into checkpointing dataset state instead. # Makes sure log_eval_steps is the same as steps_per_epoch. This is because # the precision matrix needs to be updated fully (at the end of each epoch) # when eval takes place. log_eval_steps = steps_per_epoch if first_step > 0: write_note('Advancing iterators after resuming from a checkpoint...') lr_iter = itertools.islice(lr_iter, first_step, None) train_iter = itertools.islice(train_iter, first_step, None) # Using a python integer for step here, because opt.state.step is allocated # on TPU during replication. for step, train_batch, lr_repl, reset_covmat_repl in zip( range(first_step + 1, total_steps + 1), train_iter, lr_iter, reset_covmat_iter): with jax.profiler.TraceAnnotation('train_step', step_num=step, _r=1): # TODO(jereliu): Expand to allow precision matrix resetting. (opt_repl, states_repl, loss_value, train_loop_rngs, extra_measurements) = update_fn( opt_repl, states_repl, lr_repl, reset_covmat_repl, train_batch['image'], train_batch['labels'], rng=train_loop_rngs) if jax.process_index() == 0: profiler(step) # Checkpoint saving if train_utils.itstime( step, config.get('checkpoint_steps'), total_steps, process=0): write_note('Checkpointing...') chrono.pause() train_utils.checkpointing_timeout(checkpoint_writer, config.get('checkpoint_timeout', 1)) accumulated_train_time = chrono.accum_train_time # We need to transfer the weights over now or else we risk keeping them # alive while they'll be updated in a future step, creating hard-to-debug # memory errors (see b/160593526). Also, this takes device 0's params only. # For the GP layer, we do the same for the untrainable parameters # (`states`). This is ok since the `random features` are frozen throughout # pre-training, and the `precision matrix` is a finetuning-specific # parameter that will be re-learned in the finetuning task. opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl) states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl) # Check whether we want to keep a copy of the current checkpoint. copy_step = None if train_utils.itstime(step, config.get('keep_checkpoint_steps'), total_steps): write_note('Keeping a checkpoint copy...') copy_step = step # Checkpoint should be a nested dictionary or FLAX dataclasses from # `flax.struct`. Both can be present in a checkpoint.
checkpoint_data = checkpoint_utils.CheckpointData( optimizer=opt_cpu, fixed_model_states=states_cpu, train_loop_rngs=train_loop_rngs, accumulated_train_time=accumulated_train_time) checkpoint_writer = pool.apply_async( checkpoint_utils.checkpoint_trained_model, (checkpoint_data, save_checkpoint_path, copy_step)) chrono.resume() # Report training progress if train_utils.itstime( step, config.log_training_steps, total_steps, process=0): write_note('Reporting training progress...') train_loss = loss_value[0] # Keep to return for reproducibility tests. timing_measurements, note = chrono.tick(step) write_note(note) train_measurements = {} train_measurements.update({ 'learning_rate': lr_repl[0], 'training_loss': train_loss, }) train_measurements.update(flax.jax_utils.unreplicate(extra_measurements)) train_measurements.update(timing_measurements) writer.write_scalars(step, train_measurements) # Report validation performance if train_utils.itstime(step, log_eval_steps, total_steps): write_note('Evaluating on the validation set...') chrono.pause() all_eval_results = {} for eval_name, (eval_iter, eval_steps) in eval_iter_splits.items(): start_time = time.time() # Runs evaluation loop. results_arrs = { 'y_true': [], 'y_pred': [], 'y_pred_entropy': [] } for _, batch in zip(range(eval_steps), eval_iter): batch_ncorrect, batch_losses, batch_n, batch_metric_args = ( # pylint: disable=unused-variable evaluation_fn( opt_repl.target, states_repl, batch['image'], batch['labels'])) # All results are a replicated array shaped as follows: # (local_devices, per_device_batch_size, elem_shape...) # with each local device's entry being identical as they got psum'd. # So let's just take the first one to the host as numpy. # Here we parse batch_metric_args to compute uncertainty metrics. logits, labels, _ = batch_metric_args logits = np.array(logits[0]) probs = jax.nn.softmax(logits) # From one-hot to integer labels. int_labels = np.argmax(np.array(labels[0]), axis=-1) probs = np.reshape(probs, (probs.shape[0] * probs.shape[1], -1)) int_labels = int_labels.flatten() y_pred = probs[:, 1] results_arrs['y_true'].append(int_labels) results_arrs['y_pred'].append(y_pred) # Entropy is computed at the per-epoch level (see below). results_arrs['y_pred_entropy'].append(probs) results_arrs['y_true'] = np.concatenate(results_arrs['y_true'], axis=0) results_arrs['y_pred'] = np.concatenate( results_arrs['y_pred'], axis=0).astype('float64') results_arrs['y_pred_entropy'] = vit_utils.entropy( np.concatenate(results_arrs['y_pred_entropy'], axis=0), axis=-1) time_elapsed = time.time() - start_time results_arrs['total_ms_elapsed'] = time_elapsed * 1e3 results_arrs['dataset_size'] = eval_steps * batch_size_eval all_eval_results[eval_name] = results_arrs per_pred_results, metrics_results = vit_utils.evaluate_vit_predictions( # pylint: disable=unused-variable dataset_split_to_containers=all_eval_results, is_deterministic=True, num_bins=15, return_per_pred_results=True ) # `metrics_results` is a dict of {str: jnp.ndarray} dicts, one for each # dataset. Flatten this dict so we can pass to the writer and remove empty # entries. 
flattened_metric_results = {} for dic in metrics_results.values(): for key, value in dic.items(): if value is not None: flattened_metric_results[key] = value writer.write_scalars(step, flattened_metric_results) # Optionally log to wandb if config.use_wandb: wandb.log(metrics_results, step=step) # Save per-prediction metrics results_storage_utils.save_per_prediction_results( output_dir, step, per_pred_results, verbose=False) chrono.resume() # End of step. if config.get('testing_failure_step'): # Break early to simulate infra failures in test cases. if config.testing_failure_step == step: break write_note(f'Done!\n{chrono.note}') pool.close() pool.join() writer.close() if wandb_run is not None: wandb_run.finish()
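# The update_fn above folds the epoch-boundary covariance reset into jitted
# code by interpolating with a 0/1 `reset_covmat` scalar instead of branching.
# A minimal self-contained sketch of that update (names here are illustrative,
# not the trainer's own):
import jax.numpy as jnp

def reset_precision_matrix(prec_mat_old, reset_covmat, ridge_penalty=1.0):
  """Returns prec_mat_old when reset_covmat == 0, a ridge identity when 1."""
  eye = jnp.eye(prec_mat_old.shape[0]) * ridge_penalty
  return (1.0 - reset_covmat) * prec_mat_old + reset_covmat * eye

# The schedule `float(step % steps_per_epoch == 0)` used by reset_covmat_fn
# above makes the reset fire only on the first step of each epoch.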
def eval_dataset_and_unshard(viewdir_mlp_model, viewdir_mlp_params, rgb_features, directions, source_dataset, scene_params): """Evaluates view-dependence on a sharded dataset and unshards the result. This function evaluates the view-dependence MLP on a dataset, adding back effects such as highlights. To make use of the multi-host parallelism provided by JAX, this function takes as input a sharded dataset, so each host only evaluates a slice of the data. Note that this function unshards the data before returning, which broadcasts the results back to all JAX hosts. Args: viewdir_mlp_model: A nerf.model_utils.MLP that predicts the per-ray view-dependent residual color. viewdir_mlp_params: A dict containing the MLP parameters for the per-ray view-dependence MLP. rgb_features: The RGB (+ features) input data, stored as an (N/NUM_HOSTS, NUM_LOCAL_DEVICES, H, W, 7) numpy array. directions: The direction vectors for the input data, stored as an (N/NUM_HOSTS, NUM_LOCAL_DEVICES, H, W, 3) numpy array. source_dataset: The nerf.datasets.Dataset we are evaluating. scene_params: A dict for scene specific params (bbox, rotation, resolution). Returns: A list of color images, each stored as a (H, W, 3) numpy array. """ @functools.partial(jax.pmap, in_axes=(0, 0), axis_name="batch") def pmap_eval_fn(rgb_and_feature_chunk, direction_chunk): """We need an inner function, as only JAX types can be passed to a pmap.""" residual = model_utils.viewdir_fn(viewdir_mlp_model, viewdir_mlp_params, rgb_and_feature_chunk, direction_chunk, scene_params) output = jnp.minimum(1.0, rgb_and_feature_chunk[Ellipsis, 0:3] + residual) return jax.lax.all_gather(output, axis_name="batch") num_hosts = jax.host_count() num_local_devices = jax.local_device_count() num_images = source_dataset.camtoworlds.shape[0] num_batches = math.ceil(num_images / num_hosts) num_batches = num_local_devices * math.ceil( num_batches / num_local_devices) outputs = [] for i in range(len(rgb_features)): # First, evaluate the view-dependence MLP in parallel across all devices. output_batch = pmap_eval_fn(rgb_features[i], directions[i]) output_batch = np.reshape(output_batch[0], (num_hosts, num_local_devices, source_dataset.h, source_dataset.w, 3)) # Then, make sure to populate the output array in the same order # as the original dataset. for j in range(num_local_devices): base_index = (i * num_local_devices + j) * num_hosts for k in range(num_hosts): gathered_dataset_index = base_index + k if gathered_dataset_index >= num_images: break outputs.append( np.array(output_batch[k][j]).reshape( (source_dataset.h, source_dataset.w, 3))) return outputs
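# The unsharding loop above inverts the round-robin sharding used when the
# dataset was built: dataset_index = (batch * num_local_devices + device) *
# num_hosts + host. A small sanity sketch (hypothetical topology sizes)
# showing that this mapping covers every dataset index exactly once:
import itertools

num_hosts, num_local_devices, num_batches = 2, 4, 3
indices = sorted(
    (i * num_local_devices + j) * num_hosts + k
    for i, j, k in itertools.product(
        range(num_batches), range(num_local_devices), range(num_hosts)))
assert indices == list(range(num_batches * num_local_devices * num_hosts))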
def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') tf.enable_v2_behavior() batch_size = FLAGS.batch_size learning_rate = FLAGS.learning_rate num_train_steps = FLAGS.num_train_steps num_eval_steps = FLAGS.num_eval_steps eval_freq = FLAGS.eval_frequency max_target_length = FLAGS.max_target_length max_eval_target_length = FLAGS.max_eval_target_length random_seed = FLAGS.random_seed if jax.host_id() == 0: train_summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.model_dir, 'train')) eval_summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.model_dir, 'eval')) if batch_size % jax.device_count() > 0: raise ValueError( 'Batch size must be divisible by the number of devices') train_ds, eval_ds, info_ds = input_pipeline.get_lm1b_datasets( n_devices=jax.local_device_count(), data_dir=FLAGS.data_dir, batch_size=batch_size, dynamic_batching=True, max_target_length=max_target_length, max_eval_target_length=max_eval_target_length) vocab_size = info_ds['text'].encoder.vocab_size encoder = info_ds['text'].encoder train_iter = iter(train_ds) input_shape = (batch_size, max_target_length) transformer_lm_kwargs = { 'vocab_size': vocab_size, 'emb_dim': 512, 'num_heads': 8, 'num_layers': 6, 'qkv_dim': 512, 'mlp_dim': 2048, 'max_len': max(max_target_length, max_eval_target_length) } rng = random.PRNGKey(random_seed) rng = jax.random.fold_in(rng, jax.host_id()) rng, init_rng = random.split(rng) # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap'd training update for performance. dropout_rngs = random.split(rng, jax.local_device_count()) model, cache_def = create_model(init_rng, input_shape, transformer_lm_kwargs) optimizer = create_optimizer(model, learning_rate) del model # Don't keep a copy of the initial model. start_step = 0 if FLAGS.restore_checkpoints: # Restore unreplicated optimizer + model state from last checkpoint. optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer) # Grab last step. start_step = int(optimizer.state.step) # Replicate optimizer. optimizer = jax_utils.replicate(optimizer) learning_rate_fn = create_learning_rate_scheduler( base_learning_rate=learning_rate) p_train_step = jax.pmap(functools.partial( train_step, learning_rate_fn=learning_rate_fn), axis_name='batch') p_eval_step = jax.pmap(eval_step, axis_name='batch') p_pred_step = jax.pmap(predict_step, axis_name='batch') metrics_all = [] tick = time.time() for step, batch in zip(range(start_step, num_train_steps), train_iter): batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch)) # pylint: disable=protected-access optimizer, metrics, dropout_rngs = p_train_step( optimizer, batch, dropout_rng=dropout_rngs) metrics_all.append(metrics) # Save a Checkpoint if ((step % FLAGS.checkpoint_freq == 0 and step > 0) or step == num_train_steps - 1): if jax.host_id() == 0 and FLAGS.save_checkpoints: # Save unreplicated optimizer + model state. checkpoints.save_checkpoint(FLAGS.model_dir, jax_utils.unreplicate(optimizer), step) # Periodic metric handling. 
if step % eval_freq == 0 and step > 0: metrics_all = common_utils.get_metrics(metrics_all) lr = metrics_all.pop('learning_rate').mean() metrics_sums = jax.tree_map(jnp.sum, metrics_all) denominator = metrics_sums.pop('denominator') summary = jax.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop summary['learning_rate'] = lr # Calculate (clipped) perplexity after averaging log-perplexities: summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4) logging.info('train in step: %d, loss: %.4f', step, summary['loss']) if jax.host_id() == 0: tock = time.time() steps_per_sec = eval_freq / (tock - tick) tick = tock train_summary_writer.scalar('steps per second', steps_per_sec, step) for key, val in summary.items(): train_summary_writer.scalar(key, val, step) train_summary_writer.flush() # Reset metric accumulation for next evaluation cycle. metrics_all = [] # Eval Metrics eval_metrics = [] eval_iter = iter(eval_ds) if num_eval_steps == -1: num_iter = itertools.repeat(1) else: num_iter = range(num_eval_steps) for _, eval_batch in zip(num_iter, eval_iter): # pylint: disable=protected-access eval_batch = common_utils.shard( jax.tree_map(lambda x: x._numpy(), eval_batch)) # pylint: enable=protected-access metrics = p_eval_step(optimizer.target, eval_batch) eval_metrics.append(metrics) eval_metrics = common_utils.get_metrics(eval_metrics) eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics) eval_denominator = eval_metrics_sums.pop('denominator') eval_summary = jax.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop eval_metrics_sums) # Calculate (clipped) perplexity after averaging log-perplexities: eval_summary['perplexity'] = jnp.clip(jnp.exp( eval_summary['loss']), a_max=1.0e4) logging.info('eval in step: %d, loss: %.4f', step, eval_summary['loss']) if jax.host_id() == 0: for key, val in eval_summary.items(): eval_summary_writer.scalar(key, val, step) eval_summary_writer.flush() # Fast inference of prompt extension using trained LM. rng, subrng = jax.random.split(rng) pred_rngs = random.split(subrng, jax.local_device_count()) prompt = jnp.array(encoder.encode(FLAGS.prompt)) prompt = jax_utils.replicate(prompt) prompt = jnp.reshape(prompt, (prompt.shape[0], 1, prompt.shape[1])) cache = jax_utils.replicate( cache_def.initialize_cache( (1, FLAGS.max_predict_token_length))) predicted = p_pred_step(prompt, optimizer.target, cache, pred_rngs) predicted = tohost(predicted) exemplars = '' for n in range(predicted.shape[0]): exemplars += encoder.decode(predicted[n]) + '\n\n' if jax.host_id() == 0: eval_summary_writer.text('samples', exemplars, step) eval_summary_writer.flush()
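# `tohost` is used above but defined elsewhere. In the flax examples this
# helper typically collapses the leading (device, batch) axes of a pmap output
# into a single host-side batch axis; a minimal sketch under that assumption:
import numpy as np

def tohost(x):
  """Collects a replicated [n_device, n_batch, ...] array onto the host."""
  n_device, n_batch, *remaining = x.shape
  return np.array(x).reshape((n_device * n_batch,) + tuple(remaining))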
def build_sharded_dataset_for_view_dependence(source_dataset, atlas_t, atlas_block_indices_t, atlas_params, scene_params, grid_params): """Builds a dataset that we can run the view-dependence MLP on. We ray march through a baked SNeRG model to generate images with RGB colors and features. These serve as the input for the view-dependence MLP which adds back the effects such as highlights. To make use of multi-host parallelism provided by JAX, this function shards the dataset, so that each host contains only a slice of the data. Args: source_dataset: The nerf.datasets.Dataset we should compute data for. atlas_t: A tensorflow tensor containing the texture atlas. atlas_block_indices_t: A tensorflow tensor containing the indirection grid. atlas_params: A dict with params for building and rendering with the 3D texture atlas. scene_params: A dict for scene specific params (bbox, rotation, resolution). grid_params: A dict with parameters describing the high-res voxel grid which the atlas is representing. Returns: rgb_data: The RGB (+ features) input data, stored as an (N/NUM_HOSTS, NUM_LOCAL_DEVICES, H, W, 7) numpy array. alpha_data: The alpha channel of the input data, stored as an (N/NUM_HOSTS, NUM_LOCAL_DEVICES, H, W, 1) numpy array. direction_data: The direction vectors for the input data, stored as an (N/NUM_HOSTS, NUM_LOCAL_DEVICES, H, W, 3) numpy array. ref_data: The reference RGB colors for each input data sample, stored as an (N/NUM_HOSTS, NUM_LOCAL_DEVICES, H, W, 3) numpy array. """ num_hosts = jax.host_count() num_local_devices = jax.local_device_count() host_id = jax.host_id() num_images = source_dataset.camtoworlds.shape[0] num_batches = math.ceil(num_images / num_hosts) num_batches = num_local_devices * math.ceil( num_batches / num_local_devices) rgb_list = [] alpha_list = [] viewdir_list = [] ref_list = [] for i in range(num_batches): base_index = i * num_hosts dataset_index = base_index + host_id rgb = np.zeros( (source_dataset.h, source_dataset.w, scene_params["_channels"]), dtype=np.float32) alpha = np.zeros((source_dataset.h, source_dataset.w, 1), dtype=np.float32) viewdirs = np.zeros((source_dataset.h, source_dataset.w, 3), dtype=np.float32) if dataset_index < num_images: rgb, alpha = rendering.atlas_raymarch_image_tf( source_dataset.h, source_dataset.w, source_dataset.focal, source_dataset.camtoworlds[dataset_index], atlas_t, atlas_block_indices_t, atlas_params, scene_params, grid_params) _, _, viewdirs = datasets.rays_from_camera( scene_params["_use_pixel_centers"], source_dataset.h, source_dataset.w, source_dataset.focal, np.expand_dims(source_dataset.camtoworlds[dataset_index], 0)) np_rgb = np.array(rgb).reshape( (source_dataset.h, source_dataset.w, scene_params["_channels"])) np_alpha = np.array(alpha).reshape( (source_dataset.h, source_dataset.w, 1)) np_viewdirs = viewdirs.reshape((np_rgb.shape[0], np_rgb.shape[1], 3)) if scene_params["white_bkgd"]: np_rgb[Ellipsis, 0:3] = np.ones_like(np_rgb[Ellipsis, 0:3]) * ( 1.0 - np_alpha) + np_rgb[Ellipsis, 0:3] rgb_list.append(np_rgb) alpha_list.append(np_alpha) viewdir_list.append(np_viewdirs) ref_list.append(source_dataset.images[dataset_index % num_images]) rgb_data = np.stack(rgb_list, 0).reshape( (-1, num_local_devices, source_dataset.h, source_dataset.w, scene_params["_channels"])) alpha_data = np.stack(alpha_list, 0).reshape( (-1, num_local_devices, source_dataset.h, source_dataset.w, 1)) viewdir_data = np.stack(viewdir_list, 0).reshape( (-1, num_local_devices, source_dataset.h, source_dataset.w, 3)) ref_data = 
np.stack(ref_list, 0).reshape( (-1, num_local_devices, source_dataset.h, source_dataset.w, 3)) return rgb_data, alpha_data, viewdir_data, ref_data
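# The white-background branch above composites premultiplied RGB over a white
# backdrop: out = 1 * (1 - alpha) + rgb. A tiny numpy sketch of that blend:
import numpy as np

rgb = np.array([[0.2, 0.4, 0.6]], dtype=np.float32)   # premultiplied color
alpha = np.array([[0.5]], dtype=np.float32)
composited = np.ones_like(rgb) * (1.0 - alpha) + rgb  # the white_bkgd path
print(composited)  # [[0.7 0.9 1.1]]; values above 1 are clamped later by
                   # jnp.minimum(1.0, ...) in eval_dataset_and_unshard.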
def train_and_evaluate(random_seed, batch_size, learning_rate, num_train_steps, num_eval_steps, eval_freq, max_target_length, max_eval_target_length, weight_decay, data_dir, model_dir, restore_checkpoints, save_checkpoints, checkpoint_freq, max_predict_token_length, sampling_temperature, sampling_top_k, prompt_str): """Executes model training and evaluation loop. Args: random_seed: Seed for initializing the PRNG. batch_size: Batch size for training. learning_rate: Learning rate for the Adam optimizer. num_train_steps: Number of training steps. num_eval_steps: Number of evaluation steps. eval_freq: Frequency of evaluation during training. max_target_length: Maximum length of training examples. max_eval_target_length: Maximum length of eval examples. weight_decay: Decay factor for AdamW-style weight decay. data_dir: Directory containing TFDS lm1b/subwords32k datasets. model_dir: Directory where to store model data. restore_checkpoints: Whether to restore from existing model checkpoints. save_checkpoints: Whether to save model checkpoints. checkpoint_freq: Save a checkpoint every this many steps. max_predict_token_length: Maximum token length for inference on example text. sampling_temperature: Sampling temperature for language model inference. sampling_top_k: Top-k cutoff for logit sampling. prompt_str: Prompt for language model sampling. """ if jax.host_id() == 0: train_summary_writer = tensorboard.SummaryWriter( os.path.join(model_dir, 'train')) eval_summary_writer = tensorboard.SummaryWriter( os.path.join(model_dir, 'eval')) if batch_size % jax.device_count() > 0: raise ValueError( 'Batch size must be divisible by the number of devices') train_ds, eval_ds, info_ds = input_pipeline.get_lm1b_datasets( n_devices=jax.local_device_count(), data_dir=data_dir, batch_size=batch_size, dynamic_batching=True, max_target_length=max_target_length, max_eval_target_length=max_eval_target_length) vocab_size = info_ds['text'].encoder.vocab_size encoder = info_ds['text'].encoder train_iter = iter(train_ds) input_shape = (batch_size, max_target_length) transformer_lm_kwargs = { 'vocab_size': vocab_size, 'emb_dim': 512, 'num_heads': 8, 'num_layers': 6, 'qkv_dim': 512, 'mlp_dim': 2048, 'max_len': max(max_target_length, max_eval_target_length) } rng = random.PRNGKey(random_seed) rng = jax.random.fold_in(rng, jax.host_id()) rng, init_rng = random.split(rng) # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap'd training update for performance. dropout_rngs = random.split(rng, jax.local_device_count()) model, cache_def = create_model(init_rng, input_shape, transformer_lm_kwargs) optimizer = create_optimizer(model, learning_rate, weight_decay) del model # Don't keep a copy of the initial model. start_step = 0 if restore_checkpoints: # Restore unreplicated optimizer + model state from last checkpoint. optimizer = checkpoints.restore_checkpoint(model_dir, optimizer) # Grab last step. start_step = int(optimizer.state.step) # Replicate optimizer.
optimizer = jax_utils.replicate(optimizer) learning_rate_fn = create_learning_rate_scheduler( base_learning_rate=learning_rate) p_train_step = jax.pmap(functools.partial( train_step, learning_rate_fn=learning_rate_fn), axis_name='batch') p_eval_step = jax.pmap(eval_step, axis_name='batch') p_pred_step = jax.pmap(predict_step, axis_name='batch') metrics_all = [] tick = time.time() for step, batch in zip(range(start_step, num_train_steps), train_iter): batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch)) # pylint: disable=protected-access optimizer, metrics, dropout_rngs = p_train_step( optimizer, batch, dropout_rng=dropout_rngs) metrics_all.append(metrics) # Save a Checkpoint if ((step % checkpoint_freq == 0 and step > 0) or step == num_train_steps - 1): if jax.host_id() == 0 and save_checkpoints: # Save unreplicated optimizer + model state. checkpoints.save_checkpoint(model_dir, jax_utils.unreplicate(optimizer), step) # Periodic metric handling. if step % eval_freq == 0 and step > 0: metrics_all = common_utils.get_metrics(metrics_all) lr = metrics_all.pop('learning_rate').mean() metrics_sums = jax.tree_map(jnp.sum, metrics_all) denominator = metrics_sums.pop('denominator') summary = jax.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop summary['learning_rate'] = lr # Calculate (clipped) perplexity after averaging log-perplexities: summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4) logging.info('train in step: %d, loss: %.4f', step, summary['loss']) if jax.host_id() == 0: tock = time.time() steps_per_sec = eval_freq / (tock - tick) tick = tock train_summary_writer.scalar('steps per second', steps_per_sec, step) for key, val in summary.items(): train_summary_writer.scalar(key, val, step) train_summary_writer.flush() # Reset metric accumulation for next evaluation cycle. metrics_all = [] # Eval Metrics eval_metrics = [] eval_iter = iter(eval_ds) if num_eval_steps == -1: num_iter = itertools.repeat(1) else: num_iter = range(num_eval_steps) for _, eval_batch in zip(num_iter, eval_iter): # pylint: disable=protected-access eval_batch = common_utils.shard( jax.tree_map(lambda x: x._numpy(), eval_batch)) # pylint: enable=protected-access metrics = p_eval_step(optimizer.target, eval_batch) eval_metrics.append(metrics) eval_metrics = common_utils.get_metrics(eval_metrics) eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics) eval_denominator = eval_metrics_sums.pop('denominator') eval_summary = jax.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop eval_metrics_sums) # Calculate (clipped) perplexity after averaging log-perplexities: eval_summary['perplexity'] = jnp.clip(jnp.exp( eval_summary['loss']), a_max=1.0e4) logging.info('eval in step: %d, loss: %.4f', step, eval_summary['loss']) if jax.host_id() == 0: for key, val in eval_summary.items(): eval_summary_writer.scalar(key, val, step) eval_summary_writer.flush() # Fast inference of prompt extension using trained LM. 
rng, subrng = jax.random.split(rng) pred_rngs = random.split(subrng, jax.local_device_count()) prompt = jnp.array(encoder.encode(prompt_str)) prompt = jax_utils.replicate(prompt) prompt = jnp.reshape(prompt, (prompt.shape[0], 1, prompt.shape[1])) cache = jax_utils.replicate( cache_def.initialize_cache((1, max_predict_token_length))) predicted = p_pred_step(prompt, optimizer.target, cache, pred_rngs, max_predict_token_length, sampling_temperature, sampling_top_k) predicted = tohost(predicted) exemplars = '' for n in range(predicted.shape[0]): exemplars += encoder.decode(predicted[n]) + '\n\n' if jax.host_id() == 0: eval_summary_writer.text('samples', exemplars, step) eval_summary_writer.flush()
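# `create_learning_rate_scheduler` is referenced throughout but not shown. The
# flax-examples version composes factors such as a constant base, linear
# warmup, and reciprocal-sqrt decay; a simplified sketch of that shape:
import jax.numpy as jnp

def create_learning_rate_scheduler(base_learning_rate=0.5, warmup_steps=1000):
  def step_fn(step):
    step = jnp.asarray(step, jnp.float32)
    warmup = jnp.minimum(1.0, step / warmup_steps)  # linear warmup
    decay = jnp.sqrt(warmup_steps / jnp.maximum(step, warmup_steps))
    return base_learning_rate * warmup * decay  # rsqrt decay after warmup
  return step_fn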
def train( runner, dataset_paths=gin.REQUIRED, prefetch=4, batch_size_per_device=gin.REQUIRED, validation_example_count=gin.REQUIRED, ): """Train the maze automaton. Args: runner: Helper object that runs the experiment. dataset_paths: Dictionary of dataset paths, with keys: - "train_dataset": Path to training dataset files. - "eval_dataset": Path to validation dataset files. prefetch: Maximum number of examples to prefetch in a background thread. batch_size_per_device: Batch size for each device. validation_example_count: How many examples to use when computing validation metrics. Returns: Optimizer at the end of training (for interactive debugging). """ num_devices = jax.local_device_count() logging.info("Found %d devices: %s", num_devices, jax.devices()) with contextlib.ExitStack() as exit_stack: logging.info("Setting up datasets...") raw_train_iterator = runner.build_sampling_iterator( dataset_paths["train_dataset"], example_type=graph_bundle.GraphBundle) raw_valid_iterator_factory = runner.build_one_pass_iterator_factory( dataset_paths["eval_dataset"], example_type=graph_bundle.GraphBundle, truncate_at=validation_example_count) # Add the example id into the example itself, so that we can use it to # randomly choose a goal. def reify_id(it): for item in it: yield dataclasses.replace(item, example=(item.example, item.example_id)) def reify_id_and_batch(it): return data_loading.batch(reify_id(it), (num_devices, batch_size_per_device), remainder_behavior=data_loading. BatchRemainderBehavior.PAD_ZERO) train_iterator = reify_id_and_batch(raw_train_iterator) valid_iterator_factory = ( lambda: reify_id_and_batch(raw_valid_iterator_factory())) if prefetch: train_iterator = exit_stack.enter_context( data_loading.ThreadedPrefetcher(train_iterator, prefetch)) logging.info("Setting up model...") padding_config = maze_task.PADDING_CONFIG model_def = automaton_layer.FiniteStateGraphAutomaton.partial( static_metadata=padding_config.static_max_metadata, builder=maze_task.BUILDER) # Initialize parameters randomly. _, initial_params = model_def.init( jax.random.PRNGKey(int(time.time() * 1000)), graph_bundle.zeros_like_padded_example( padding_config).automaton_graph, dynamic_metadata=padding_config.static_max_metadata) model = flax.nn.Model(model_def, initial_params) optimizer = flax.optim.Adam().create(model) extra_artifacts = { "builder.pickle": maze_task.BUILDER, } return runner.training_loop( optimizer=optimizer, train_iterator=train_iterator, loss_fn=loss_fn, validation_fn=train_util.build_averaging_validator( loss_fn, valid_iterator_factory), extra_artifacts=extra_artifacts)
def main(config, output_dir): seed = config.get('seed', 0) tf.random.set_seed(seed) if config.get('data_dir'): logging.info('data_dir=%s', config.data_dir) logging.info('Output dir: %s', output_dir) tf.io.gfile.makedirs(output_dir) # Create an asynchronous multi-metric writer. writer = metric_writers.create_default_writer( output_dir, just_logging=jax.process_index() > 0) # The pool is used to perform misc operations such as logging in async way. pool = multiprocessing.pool.ThreadPool() def write_note(note): if jax.process_index() == 0: logging.info('NOTE: %s', note) write_note('Initializing...') batch_size = config.batch_size batch_size_eval = config.get('batch_size_eval', batch_size) if (batch_size % jax.device_count() != 0 or batch_size_eval % jax.device_count() != 0): raise ValueError( f'Batch sizes ({batch_size} and {batch_size_eval}) must ' f'be divisible by device number ({jax.device_count()})') local_batch_size = batch_size // jax.process_count() local_batch_size_eval = batch_size_eval // jax.process_count() logging.info( 'Global batch size %d on %d hosts results in %d local batch size. ' 'With %d devices per host (%d devices total), that\'s a %d per-device ' 'batch size.', batch_size, jax.process_count(), local_batch_size, jax.local_device_count(), jax.device_count(), local_batch_size // jax.local_device_count()) write_note('Initializing val dataset(s)...') def _get_val_split(dataset, split, pp_eval, data_dir=None): # We do ceil rounding such that we include the last incomplete batch. nval_img = input_utils.get_num_examples( dataset, split=split, process_batch_size=local_batch_size_eval, drop_remainder=False, data_dir=data_dir) val_steps = int(np.ceil(nval_img / batch_size_eval)) logging.info('Running validation for %d steps for %s, %s', val_steps, dataset, split) if isinstance(pp_eval, str): pp_eval = preprocess_spec.parse( spec=pp_eval, available_ops=preprocess_utils.all_ops()) val_ds = input_utils.get_data(dataset=dataset, split=split, rng=None, process_batch_size=local_batch_size_eval, preprocess_fn=pp_eval, cache=config.get('val_cache', 'batched'), num_epochs=1, repeat_after_batching=True, shuffle=False, prefetch_size=config.get( 'prefetch_to_host', 2), drop_remainder=False, data_dir=data_dir) return val_ds val_ds_splits = { 'val': _get_val_split(config.dataset, split=config.val_split, pp_eval=config.pp_eval, data_dir=config.get('data_dir')) } if config.get('test_split'): val_ds_splits.update({ 'test': _get_val_split(config.dataset, split=config.test_split, pp_eval=config.pp_eval, data_dir=config.get('data_dir')) }) if config.get('eval_on_cifar_10h'): cifar10_to_cifar10h_fn = data_uncertainty_utils.create_cifar10_to_cifar10h_fn( config.get('data_dir', None)) preprocess_fn = preprocess_spec.parse( spec=config.pp_eval_cifar_10h, available_ops=preprocess_utils.all_ops()) pp_eval = lambda ex: preprocess_fn(cifar10_to_cifar10h_fn(ex)) val_ds_splits['cifar_10h'] = _get_val_split( 'cifar10', split=config.get('cifar_10h_split') or 'test', pp_eval=pp_eval, data_dir=config.get('data_dir')) elif config.get('eval_on_imagenet_real'): imagenet_to_real_fn = data_uncertainty_utils.create_imagenet_to_real_fn( ) preprocess_fn = preprocess_spec.parse( spec=config.pp_eval_imagenet_real, available_ops=preprocess_utils.all_ops()) pp_eval = lambda ex: preprocess_fn(imagenet_to_real_fn(ex)) val_ds_splits['imagenet_real'] = _get_val_split( 'imagenet2012_real', split=config.get('imagenet_real_split') or 'validation', pp_eval=pp_eval, data_dir=config.get('data_dir')) ood_ds = {} if 
config.get('ood_datasets') and config.get('ood_methods'): if config.get( 'ood_methods'): # config.ood_methods is not an empty list logging.info('loading OOD dataset = %s', config.get('ood_datasets')) ood_ds, ood_ds_names = ood_utils.load_ood_datasets( config.dataset, config.ood_datasets, config.ood_split, config.pp_eval, config.pp_eval_ood, config.ood_methods, config.train_split, config.get('data_dir'), _get_val_split, ) write_note('Initializing model...') logging.info('config.model = %s', config.model) model = ub.models.vision_transformer(num_classes=config.num_classes, **config.model) ensemble_pred_fn = functools.partial(ensemble_prediction_fn, model.apply) @functools.partial(jax.pmap, axis_name='batch') def evaluation_fn(params, images, labels, mask): # params is a dict of the form: # {'model_1': params_model_1, 'model_2': params_model_2, ...} # Ignore the entries with all zero labels for evaluation. mask *= labels.max(axis=1) loss_as_str = config.get('loss', 'sigmoid_xent') ens_logits, ens_prelogits = ensemble_pred_fn(params, images, loss_as_str) label_indices = config.get('label_indices') logging.info('!!! mask %s, label_indices %s', mask, label_indices) if label_indices: ens_logits = ens_logits[:, label_indices] # Note that logits and labels are usually of the shape [batch, num_classes]. # But for OOD data, when num_classes_ood > num_classes_ind, we need to # adjust labels to labels[:, :config.num_classes] to match the shape of # logits. This is just to avoid a shape mismatch. The output losses do not # have any meaning for OOD data, because OOD examples do not belong to any # IND class. losses = getattr(train_utils, loss_as_str)( logits=ens_logits, labels=labels[:, :( len(label_indices) if label_indices else config.num_classes)], reduction=False) loss = jax.lax.psum(losses * mask, axis_name='batch') top1_idx = jnp.argmax(ens_logits, axis=1) # Extracts the label at the highest logit index for each image. top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0] ncorrect = jax.lax.psum(top1_correct * mask, axis_name='batch') n = jax.lax.psum(mask, axis_name='batch') metric_args = jax.lax.all_gather( [ens_logits, labels, ens_prelogits, mask], axis_name='batch') return ncorrect, loss, n, metric_args @functools.partial(jax.pmap, axis_name='batch') def cifar_10h_evaluation_fn(params, images, labels, mask): loss_as_str = config.get('loss', 'softmax_xent') ens_logits, ens_prelogits = ensemble_pred_fn(params, images, loss_as_str) label_indices = config.get('label_indices') if label_indices: ens_logits = ens_logits[:, label_indices] losses = getattr(train_utils, loss_as_str)(logits=ens_logits, labels=labels, reduction=False) loss = jax.lax.psum(losses, axis_name='batch') top1_idx = jnp.argmax(ens_logits, axis=1) # Extracts the label at the highest logit index for each image. one_hot_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)] top1_correct = jnp.take_along_axis(one_hot_labels, top1_idx[:, None], axis=1)[:, 0] ncorrect = jax.lax.psum(top1_correct, axis_name='batch') n = jax.lax.psum(one_hot_labels, axis_name='batch') metric_args = jax.lax.all_gather( [ens_logits, labels, ens_prelogits, mask], axis_name='batch') return ncorrect, loss, n, metric_args # Setup function for computing representation. @functools.partial(jax.pmap, axis_name='batch') def representation_fn(params, images, labels, mask): # Return shape [batch_size, representation_size * ensemble_size]. During # few-shot eval, a single linear regressor is applied over all dimensions.
representation = [] for p in params.values(): _, outputs = model.apply({'params': flax.core.freeze(p)}, images, train=False) representation += [outputs[config.fewshot.representation_layer]] representation = jnp.concatenate(representation, axis=1) representation = jax.lax.all_gather(representation, 'batch') labels = jax.lax.all_gather(labels, 'batch') mask = jax.lax.all_gather(mask, 'batch') return representation, labels, mask write_note('Load checkpoints...') ensemble_params = load_checkpoints(config) write_note('Replicating...') ensemble_params = flax.jax_utils.replicate(ensemble_params) if jax.process_index() == 0: writer.write_hparams(dict(config)) write_note('Initializing few-shotters...') fewshotter = None if 'fewshot' in config and fewshot is not None: fewshotter = fewshot.FewShotEvaluator( representation_fn, config.fewshot, config.fewshot.get('batch_size') or batch_size_eval) # Note: we return the train loss, val loss, and fewshot best l2s for use in # reproducibility unit tests. val_loss = {val_name: -jnp.inf for val_name, _ in val_ds_splits.items()} fewshot_results = {'dummy': {(0, 1): -jnp.inf}} step = 1 # Report validation performance. write_note('Evaluating on the validation set...') for val_name, val_ds in val_ds_splits.items(): # Sets up evaluation metrics. ece_num_bins = config.get('ece_num_bins', 15) auc_num_bins = config.get('auc_num_bins', 1000) ece = rm.metrics.ExpectedCalibrationError(num_bins=ece_num_bins) calib_auc = rm.metrics.CalibrationAUC(correct_pred_as_pos_label=False) oc_auc_0_5 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.005, num_bins=auc_num_bins) oc_auc_1 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.01, num_bins=auc_num_bins) oc_auc_2 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.02, num_bins=auc_num_bins) oc_auc_5 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.05, num_bins=auc_num_bins) label_diversity = tf.keras.metrics.Mean() sample_diversity = tf.keras.metrics.Mean() ged = tf.keras.metrics.Mean() # Runs evaluation loop. val_iter = input_utils.start_input_pipeline( val_ds, config.get('prefetch_to_device', 1)) ncorrect, loss, nseen = 0, 0, 0 for batch in val_iter: if val_name == 'cifar_10h': batch_ncorrect, batch_losses, batch_n, batch_metric_args = ( cifar_10h_evaluation_fn(ensemble_params, batch['image'], batch['labels'], batch['mask'])) else: batch_ncorrect, batch_losses, batch_n, batch_metric_args = ( evaluation_fn(ensemble_params, batch['image'], batch['labels'], batch['mask'])) # All results are a replicated array shaped as follows: # (local_devices, per_device_batch_size, elem_shape...) # with each local device's entry being identical as they got psum'd. # So let's just take the first one to the host as numpy. ncorrect += np.sum(np.array(batch_ncorrect[0])) loss += np.sum(np.array(batch_losses[0])) nseen += np.sum(np.array(batch_n[0])) if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent': # Here we parse batch_metric_args to compute uncertainty metrics. # (e.g., ECE or Calibration AUC). logits, labels, _, masks = batch_metric_args masks = np.array(masks[0], dtype=np.bool) logits = np.array(logits[0]) probs = jax.nn.softmax(logits) # From one-hot to integer labels, as required by ECE. 
int_labels = np.argmax(np.array(labels[0]), axis=-1) int_preds = np.argmax(logits, axis=-1) confidence = np.max(probs, axis=-1) for p, c, l, d, m, label in zip(probs, confidence, int_labels, int_preds, masks, labels[0]): ece.add_batch(p[m, :], label=l[m]) calib_auc.add_batch(d[m], label=l[m], confidence=c[m]) # TODO(jereliu): Extend to support soft multi-class probabilities. oc_auc_0_5.add_batch(d[m], label=l[m], custom_binning_score=c[m]) oc_auc_1.add_batch(d[m], label=l[m], custom_binning_score=c[m]) oc_auc_2.add_batch(d[m], label=l[m], custom_binning_score=c[m]) oc_auc_5.add_batch(d[m], label=l[m], custom_binning_score=c[m]) if val_name == 'cifar_10h' or val_name == 'imagenet_real': batch_label_diversity, batch_sample_diversity, batch_ged = data_uncertainty_utils.generalized_energy_distance( label[m], p[m, :], config.num_classes) label_diversity.update_state(batch_label_diversity) sample_diversity.update_state(batch_sample_diversity) ged.update_state(batch_ged) val_loss[val_name] = loss / nseen # Keep for reproducibility tests. val_measurements = { f'{val_name}_prec@1': ncorrect / nseen, f'{val_name}_loss': val_loss[val_name], } if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent': val_measurements[f'{val_name}_ece'] = ece.result()['ece'] val_measurements[f'{val_name}_calib_auc'] = calib_auc.result( )['calibration_auc'] val_measurements[f'{val_name}_oc_auc_0.5%'] = oc_auc_0_5.result( )['collaborative_auc'] val_measurements[f'{val_name}_oc_auc_1%'] = oc_auc_1.result( )['collaborative_auc'] val_measurements[f'{val_name}_oc_auc_2%'] = oc_auc_2.result( )['collaborative_auc'] val_measurements[f'{val_name}_oc_auc_5%'] = oc_auc_5.result( )['collaborative_auc'] writer.write_scalars(step, val_measurements) if val_name == 'cifar_10h' or val_name == 'imagenet_real': cifar_10h_measurements = { f'{val_name}_label_diversity': label_diversity.result(), f'{val_name}_sample_diversity': sample_diversity.result(), f'{val_name}_ged': ged.result(), } writer.write_scalars(step, cifar_10h_measurements) # OOD eval # Entries in the ood_ds dict include: # (ind_dataset, ood_dataset1, ood_dataset2, ...). # OOD metrics are computed using ind_dataset paired with each of the # ood_dataset. When Mahalanobis distance method is applied, train_ind_ds # is also included in the ood_ds. if ood_ds and config.ood_methods: ood_measurements = ood_utils.eval_ood_metrics(ood_ds, ood_ds_names, config.ood_methods, evaluation_fn, ensemble_params, n_prefetch=config.get( 'prefetch_to_device', 1)) writer.write_scalars(step, ood_measurements) if 'fewshot' in config and fewshotter is not None: # Compute few-shot on-the-fly evaluation. write_note('Few-shot evaluation...') # Keep `results` to return for reproducibility tests. fewshot_results, best_l2 = fewshotter.run_all(ensemble_params, config.fewshot.datasets) # TODO(dusenberrymw): Remove this once fewshot.py is updated. def make_writer_measure_fn(step): def writer_measure(name, value): writer.write_scalars(step, {name: value}) return writer_measure fewshotter.walk_results(make_writer_measure_fn(step), fewshot_results, best_l2) write_note('Done!') pool.close() pool.join() writer.close() # Return final training loss, validation loss, and fewshot results for # reproducibility test cases. return val_loss, fewshot_results
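# `ensemble_prediction_fn` is referenced above but lives elsewhere. A hedged
# sketch of what it plausibly does: run each ensemble member through the
# shared apply function and average the members' predictive distributions in
# probability space (the averaging choice and the handling of `loss_as_str`
# here are assumptions):
import jax
import jax.numpy as jnp
import flax

def ensemble_prediction_fn(apply_fn, params_dict, images, loss_as_str):
  del loss_as_str  # A fuller version might switch sigmoid/softmax on this.
  log_probs, pre_logits = [], []
  for member_params in params_dict.values():
    logits, out = apply_fn({'params': flax.core.freeze(member_params)},
                           images, train=False)
    log_probs.append(jax.nn.log_softmax(logits))
    pre_logits.append(out['pre_logits'])
  # Mean of member probabilities, computed stably in log space.
  ens_logits = jax.nn.logsumexp(jnp.stack(log_probs), axis=0) - jnp.log(
      len(params_dict))
  ens_prelogits = jnp.concatenate(pre_logits, axis=-1)
  return ens_logits, ens_prelogits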
def parallel_train_loop(key, init_params, loss_fn, summarize_fn=default_summarize, lr=1e-4, num_steps=int(1e5), summarize_every=100, checkpoint_every=5000, clobber_checkpoint=False, logdir="/tmp/lda_inference"): loss_fn = jax.jit(loss_fn) optimizer_def = optim.Adam() local_optimizer = optimizer_def.create(init_params) local_optimizer = util.maybe_load_checkpoint( logdir, local_optimizer, clobber_checkpoint=clobber_checkpoint) first_step = local_optimizer.state.step repl_optimizer = jax_utils.replicate(local_optimizer) lr_fn = util.create_learning_rate_scheduler(base_learning_rate=lr) @functools.partial(jax.pmap, axis_name="batch") def train_step(optimizer, key): key, subkey = jax.random.split(key) loss_grad = jax.grad(loss_fn, argnums=0)(optimizer.target, key) loss_grad = jax.lax.pmean(loss_grad, "batch") new_optimizer = optimizer.apply_gradient( loss_grad, learning_rate=lr_fn(optimizer.state.step)) return new_optimizer, subkey sw = SummaryWriter(logdir) repl_key = jax.pmap(jax.random.PRNGKey)(jnp.arange(jax.local_device_count())) start = timeit.default_timer() for t in range(first_step, num_steps): if t % checkpoint_every == 0 and t != first_step: optimizer = jax_utils.unreplicate(repl_optimizer) checkpoints.save_checkpoint(logdir, optimizer, optimizer.state.step, keep=3) print("Checkpoint saved for step %d" % optimizer.state.step) repl_optimizer, repl_key = train_step(repl_optimizer, repl_key) if t % summarize_every == 0: key, subkey = jax.random.split(jax_utils.unreplicate(repl_key)) optimizer = jax_utils.unreplicate(repl_optimizer) loss_val = loss_fn(optimizer.target, key) print("Step %d loss: %0.4f" % (t, loss_val)) sw.scalar("loss", loss_val, step=t) summarize_fn(sw, t, optimizer.target, subkey) end = timeit.default_timer() if t == 0: steps_per_sec = 1. / (end - start) else: steps_per_sec = summarize_every / (end - start) print("Steps/sec: %0.2f" % steps_per_sec) sw.scalar("steps_per_sec", steps_per_sec, step=t) start = end sw.flush() sys.stdout.flush()
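# A hedged usage sketch for `parallel_train_loop` with a toy quadratic loss;
# the loss function takes (params, rng) exactly as the loop expects. The call
# is commented out because it writes checkpoints and summaries to `logdir`.
import jax
import jax.numpy as jnp

def toy_loss(params, key):
  noise = jax.random.normal(key, params['w'].shape)
  return jnp.sum((params['w'] - 3.0 + 0.01 * noise) ** 2)

# parallel_train_loop(jax.random.PRNGKey(0), {'w': jnp.zeros(4)}, toy_loss,
#                     num_steps=1000, logdir='/tmp/toy_run')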
def run_train(run_configuration): """Runs the training workflow.""" config = run_configuration.config run_dir = run_configuration.run_dir adapter = run_configuration.adapter log_dir = os.path.join(run_dir, 'train') checkpoint_path = run_configuration.original_checkpoint_path dataset = run_configuration.dataset_info.dataset info = run_configuration.dataset_info.info random_seed = 0 rng = jax.random.PRNGKey(random_seed) rng = jax.random.fold_in(rng, jax.host_id()) rng, init_rng = jax.random.split(rng) dropout_rngs = jax.random.split(rng, jax.local_device_count()) # Set up optimizer. optimizer = adapter.create_optimizer(run_configuration, rng=init_rng) # Set up train step. train_step = adapter.make_train_step() if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter(log_dir) # Set up checkpointing. # TODO(dbieber): Set up phoenix. checkpoint_dir = checkpoint_utils.build_checkpoint_dir(run_dir) if checkpoint_path is None: checkpoint_path = checkpoint_utils.latest_checkpoint(checkpoint_dir) optimizer = checkpoint_utils.handle_restart_behavior( checkpoint_path, optimizer, config) start_step = int(optimizer.state.step) num_train_steps = config.train.total_steps # Replicate optimizer. optimizer = flax.jax_utils.replicate(optimizer) # Begin training loop. dataset_iter_raw = iter(dataset) dataset_iter = adapter.preprocess(dataset_iter_raw) summary_freq = config.logging.summary_freq metrics_all = [] tick = time.time() for step, example in zip(range(start_step, num_train_steps), dataset_iter): train_inputs = adapter.get_train_inputs(example) optimizer, metrics, dropout_rngs, logits, state = train_step( optimizer, train_inputs, dropout_rngs) metrics_all.append(metrics) # Save a Checkpoint if ((step % config.logging.save_freq == 0 and step > 0) or step == num_train_steps - 1): if jax.host_id() == 0 and config.logging.save_freq: # Save unreplicated optimizer + model state. checkpoint_utils.save_checkpoint( checkpoint_dir, jax_utils.unreplicate(optimizer), step) # Periodic metric handling. if summary_freq and step % summary_freq == 0 and step > 0: metrics_all = common_utils.get_metrics(metrics_all) lr = metrics_all.pop('learning_rate').mean() metrics_sums = jax.tree_map(jnp.sum, metrics_all) denominator = metrics_sums.pop('denominator') summary = jax.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop summary['learning_rate'] = lr # Calculate (clipped) perplexity after averaging log-perplexities: summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4) logging.info('train step: %d, loss: %.4f', step, summary['loss']) if jax.host_id() == 0: tock = time.time() steps_per_sec = summary_freq / (tock - tick) examples_per_sec = denominator / (tock - tick) tick = tock summary_writer.scalar('per-second/steps', steps_per_sec, step) summary_writer.scalar('per-second/examples', examples_per_sec, step) for key, val in summary.items(): summary_writer.scalar(key, val, step) adapter.write_summaries(example, logits, summary_writer, info, step, state) summary_writer.flush() # Reset metric accumulation for next evaluation cycle. metrics_all = []
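# The summary blocks above lean on `common_utils.get_metrics` from flax to
# turn a list of per-step, per-device metric trees into a single tree of
# stacked host arrays. A rough behavioral sketch (the real helper lives in
# flax.training.common_utils):
import jax
import numpy as np

def get_metrics_sketch(step_metrics):
  # Metrics were already reduced across devices inside the pmapped step, so
  # every device holds the same values; keep device 0's copy on the host.
  step_metrics = jax.tree_map(lambda x: x[0], step_metrics)
  step_metrics = jax.device_get(step_metrics)
  # Stack the per-step trees into one tree of [num_steps, ...] arrays.
  return jax.tree_map(lambda *xs: np.stack(xs), *step_metrics)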
def predict_and_compute_score(*, p_pred_step, p_init_cache, target, predict_ds, decode_io, decode_program, beam_size, num_partial_programs, use_best_first_search = False, slow_decode = False): """Generates program and computes score.""" n_devices = jax.local_device_count() pred_acc = 0 pred_denominator = 0 ios, targets, predictions = [], [], [] for batches in predict_ds.as_numpy_iterator(): pred_batch = batches # Handle final odd-sized batch by padding instead of dropping it. cur_pred_batch_size = pred_batch[0].shape[0] if cur_pred_batch_size % n_devices: padded_size = int( np.ceil(cur_pred_batch_size / n_devices) * n_devices) # pylint: disable=cell-var-from-loop pred_batch = jax.tree_map( lambda x: pad_examples(x, padded_size), pred_batch) inputs, outputs, programs, split_outputs = common_utils.shard(pred_batch) cache = (p_init_cache(inputs, outputs, programs[:, :, 0]) if not slow_decode else None) predicted, log_probs = p_pred_step(target, inputs, outputs, cache, beam_size, split_outputs=split_outputs) predicted, log_probs = map(tohost, (predicted, log_probs)) inputs, outputs, programs = map(tohost, (inputs, outputs, programs)) pred_denominator += programs.shape[0] for i, partial_beams in enumerate(predicted): inps, outs = decode_io(inputs[i], outputs[i]) # Find the best orderings of partial programs. # partial_seqs shape == [n_beam, n_partial] if use_best_first_search: partial_seqs = best_first_search(log_probs[i], beam_size) else: partial_seqs = beam_decoder(log_probs[i], beam_size) # beams shape == [n_beam, n_partial, length] beams = partial_beams[np.arange(num_partial_programs), partial_seqs] # Execute predicted programs on i/o examples. p, p_score = compute_score(beams, inps, outs, decode_program) if p_score >= len(inps): pred_acc += 1 ios.append(' ; '.join(map(str, zip(inps, outs)))) targets.append(decode_program(programs[i]).to_string()) try: predictions.append(p.to_string()) except: # pylint: disable=bare-except predictions.append('') logging.info('ios: %s', ios[-1]) logging.info('target: %s', targets[-1]) beams_log = [] for beam in beams: try: beams_log.append(decode_program(beam).to_string()) except: # pylint: disable=bare-except beams_log.append('None') logging.info('predicted beam: %s', '\n'.join(beams_log)) all_pred_acc, all_pred_denominator = per_host_sum_pmap( jax.tree_map(np.array, (pred_acc, pred_denominator))) # Record beam search results as text summaries. message = [] for n in np.random.choice(np.arange(len(predictions)), 8): text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n' f'predicted: {predictions[n]}\n\n') message.append(text) return all_pred_acc / all_pred_denominator, message
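# `pad_examples` is used above to round odd-sized final batches up to a
# multiple of the device count. A minimal sketch under the assumption that it
# repeats the last example to fill the padding slots:
import numpy as np

def pad_examples(x, desired_batch_size):
  """Pads a batch along axis 0 by tiling the final example."""
  batch_pad = desired_batch_size - x.shape[0]
  tile_dims = [1] * x.ndim
  tile_dims[0] = batch_pad
  return np.concatenate([x, np.tile(x[-1:], tile_dims)], axis=0)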
split=process_split, batch_dims=(), rng=rng, filter_fn=None, preprocess_fn=preprocess_fn, decoders={"image": tfds.decode.SkipDecoding()}, cache=cache == "loaded", num_epochs=num_epochs if not repeat_after_batching else 1, shuffle=shuffle, shuffle_buffer_size=shuffle_buffer_size, prefetch_size=0, pad_up_to_batches=None, drop_remainder=drop_remainder, ) num_devices = jax.local_device_count() if drop_remainder: # If we're dropping the remainder, we can take the fast path of double # batching to [num_devices, batch_size_per_device] and then adding a mask of # ones for the two batch dimensions. batch_size_per_device = process_batch_size // num_devices batch_dims = [num_devices, batch_size_per_device] for batch_size in reversed(batch_dims): dataset = dataset.batch(batch_size, drop_remainder=True) dataset = dataset.map(lambda xs: _add_mask(xs, 2), num_parallel_calls=tf.data.AUTOTUNE) else: # If we're not dropping the remainder, then we define a flattened batch size # that would divide evenly across devices, and then batch to that size with # drop_remainder=False. Then we add a mask of ones for the examples given, # and zeros for any examples added as padding.
def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') # Grab pretrain text data if FLAGS.target_text: targets_decoded_pt = [] for i in range(1, 9): with tf.io.gfile.GFile(FLAGS.target_text % i, 'rb') as f: pt_targs_tmp = pickle.load(f) targets_decoded_pt.extend(pt_targs_tmp) else: train_ds, (encoder_in, encoder_tgt) = input_pipeline.get_wmt_is_datasets( n_devices=jax.local_device_count(), dataset_name=FLAGS.dataset_name, shard_idx=jax.process_index(), shard_count=jax.process_count(), data_dir=FLAGS.data_dir, vocab_path=FLAGS.vocab_path, target_vocab_size=32000, batch_size=1024, max_length=256, paracrawl_size=FLAGS.paracrawl_size, split_tokenizer=FLAGS.split_tokenizer) train_data = iter(train_ds) eos_id = decode.EOS_ID def decode_tokens(encoder, toks): valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32) return encoder.detokenize(valid_toks).numpy().decode('utf-8') targets = [] inputs = [] for x in train_data: trg = x['targets']._numpy() # pylint:disable=protected-access ins = x['inputs']._numpy() # pylint:disable=protected-access targets.append(trg) inputs.append(ins) # flatten targets_decoded_pt # pylint:disable=g-complex-comprehension targets_flat = [t for batch_t in targets for t in batch_t] inputs_flat = [t for batch_t in inputs for t in batch_t] # pylint:enable=g-complex-comprehension # decode only the slice for this one targets_decoded_pt = [] start = PROC_SIZE * FLAGS.slice end = PROC_SIZE * (FLAGS.slice + 1) if FLAGS.slice == 14: end = 9999999 for i, x in enumerate(targets_flat[start:end]): if FLAGS.clf_inputs: input_decode = decode_tokens(encoder_in, inputs_flat[i + start]) if FLAGS.clf_targets: target_decode = decode_tokens(encoder_tgt, x) if FLAGS.clf_inputs and FLAGS.clf_targets: decode_tok = input_decode + ' [SEP] ' + target_decode else: decode_tok = target_decode if FLAGS.clf_targets else input_decode targets_decoded_pt.append(decode_tok) # Load model cache_dir = '/tmp/' # model weights get temporarily written to this directory path = FLAGS.bert_base_dir trained_path = FLAGS.bert_clf_dir config = transformers.BertConfig.from_pretrained(os.path.join( trained_path, 'config.json'), num_labels=2, cache_dir=cache_dir) tokenizer = transformers.BertTokenizer.from_pretrained(path, cache_dir=cache_dir) model = transformers.TFBertForSequenceClassification.from_pretrained( os.path.join(trained_path, 'tf_model.h5'), config=config, cache_dir=cache_dir) if FLAGS.target_text: # If we read the entire dataset from text, select the slice to encode start = PROC_SIZE * FLAGS.slice end = PROC_SIZE * (FLAGS.slice + 1) if FLAGS.slice == 14: end = 9999999 input_targets = targets_decoded_pt[start:end] else: # the targets were decoded above so just use the ones that were decoded input_targets = targets_decoded_pt encoding = tokenizer(input_targets, return_tensors='tf', padding=True, truncation=True, max_length=512) train_dataset = tf.data.Dataset.from_tensor_slices((dict(encoding), )) batch_size = 256 if FLAGS.clf_inputs and FLAGS.clf_targets: # multiling model is larger batch_size = 128 train_dataset = train_dataset.batch(batch_size) logits = model.predict(train_dataset) probs = softmax(logits.logits, axis=1) clf_score_name = FLAGS.save_dir + '/CLR_scores_' + str( FLAGS.slice) + '.csv' with tf.io.gfile.GFile(clf_score_name, 'w') as f: writer = csv.writer(f) for p in probs: writer.writerow([p[1]])
def barrier():
  """MPI-like barrier."""
  jax.device_get(_barrier(jnp.ones((jax.local_device_count(),))))
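
# `_barrier` is not defined in this excerpt. A common definition (an
# assumption, not necessarily the author's) is a pmapped all-reduce; the
# device_get above then blocks until every local device has run the psum:
_barrier = jax.pmap(lambda x: jax.lax.psum(x, axis_name='i'), axis_name='i')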
def create_buffers(self, name, param):
  """Prepares all momentum buffers for each parameter."""
  state = {'step': jnp.zeros(jax.local_device_count())}
  if self.get_hyper(name, 'momentum') is not None:
    state['momentum'] = jnp.zeros_like(param)
  return state
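
# Usage note (hypothetical; the enclosing optimizer class is not shown):
# `create_buffers` is typically invoked once per parameter leaf when the
# optimizer state is first built, e.g.
#   state = {name: self.create_buffers(name, p) for name, p in params.items()}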
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  FLAGS.alsologtostderr = True

  train_split = dataset.Split.from_string(FLAGS.train_split)
  eval_split = dataset.Split.from_string(FLAGS.eval_split)

  # The total batch size is the batch size across all hosts and devices. In a
  # multi-host training setup each host will only see a batch size of
  # `total_train_batch_size / jax.host_count()`.
  total_train_batch_size = FLAGS.train_device_batch_size * jax.device_count()
  num_train_steps = ((train_split.num_examples * FLAGS.train_epochs) //
                     total_train_batch_size)

  local_device_count = jax.local_device_count()
  train_dataset = dataset.load(
      train_split,
      is_training=True,
      batch_dims=[local_device_count, FLAGS.train_device_batch_size],
      bfloat16=FLAGS.train_bfloat16,
      transpose=FLAGS.dataset_transpose)
  if jax.default_backend() == 'gpu':
    # TODO(tomhennigan): This could be removed if XLA:GPU's allocator changes.
    train_dataset = dataset.double_buffer(train_dataset)

  # For initialization we need the same random key on each device.
  rng = jax.random.PRNGKey(FLAGS.train_init_random_seed)
  rng = jnp.broadcast_to(rng, (local_device_count,) + rng.shape)
  # Initialization requires an example input.
  batch = next(train_dataset)
  params, state, opt_state = jax.pmap(make_initial_state)(rng, batch)

  # Print a useful summary of the execution of our module.
  summary = hk.experimental.tabulate(train_step)(params, state, opt_state,
                                                 batch)
  for line in summary.split('\n'):
    logging.info(line)

  eval_every = FLAGS.train_eval_every
  log_every = FLAGS.train_log_every

  with time_activity('train'):
    for step_num in range(num_train_steps):
      # Take a single training step.
      with jax.profiler.StepTraceContext('train', step_num=step_num):
        params, state, opt_state, train_scalars = (
            train_step(params, state, opt_state, next(train_dataset)))

      # By default we do not evaluate during training, but you can configure
      # this with a flag.
      if eval_every > 0 and step_num and step_num % eval_every == 0:
        with time_activity('eval during train'):
          eval_scalars = evaluate(eval_split, params, state)
        logging.info('[Eval %s/%s] %s', step_num, num_train_steps,
                     eval_scalars)

      # Log progress at fixed intervals.
      if step_num and step_num % log_every == 0:
        train_scalars = jax.tree_map(lambda v: np.mean(v).item(),
                                     jax.device_get(train_scalars))
        logging.info('[Train %s/%s] %s', step_num, num_train_steps,
                     train_scalars)

  # Once training has finished we run eval one more time to get final results.
  with time_activity('final eval'):
    eval_scalars = evaluate(eval_split, params, state)
  logging.info('[Eval FINAL]: %s', eval_scalars)
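
# `time_activity` is used above but not defined in this excerpt. A plausible
# minimal implementation (an assumption) is a context manager that logs the
# wall-clock duration of the wrapped block:
import contextlib
import time

@contextlib.contextmanager
def time_activity(activity_name):
  logging.info('[Timing] %s start.', activity_name)
  start = time.time()
  yield
  logging.info('[Timing] %s finished (%.2fs).', activity_name,
               time.time() - start)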
def main(_):
  tf.enable_v2_behavior()
  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  # Performance gains on TPU by switching to hardware bernoulli.
  def hardware_bernoulli(rng_key, p=jax.numpy.float32(0.5), shape=None):
    lax_key = jax.lax.tie_in(rng_key, 0.0)
    return jax.lax.rng_uniform(lax_key, 1.0, shape) < p

  def set_hardware_bernoulli():
    jax.random.bernoulli = hardware_bernoulli

  set_hardware_bernoulli()

  # As we gridsearch the weight decay and the learning rate, we add them to
  # the output directory path so that each model has its own directory to
  # save the results in. We also add the `run_seed`, which is "gridsearched"
  # on to replicate an experiment several times.
  output_dir_suffix = os.path.join('lr_' + str(FLAGS.learning_rate),
                                   'wd_' + str(FLAGS.weight_decay),
                                   'rho_' + str(FLAGS.sam_rho),
                                   'seed_' + str(FLAGS.run_seed))
  output_dir = os.path.join(FLAGS.output_dir, output_dir_suffix)
  if not gfile.exists(output_dir):
    gfile.makedirs(output_dir)

  num_devices = jax.local_device_count() * jax.host_count()
  assert FLAGS.batch_size % num_devices == 0
  local_batch_size = FLAGS.batch_size // num_devices
  info = 'Total batch size: {} ({} x {} replicas)'.format(
      FLAGS.batch_size, local_batch_size, num_devices)
  logging.info(info)

  # Initialize so that datasets that never set it (e.g. svhn) don't hit a
  # NameError when image_size is read below.
  image_size = None
  if FLAGS.dataset == 'cifar10':
    if FLAGS.from_pretrained_checkpoint:
      image_size = efficientnet.name_to_image_size(FLAGS.model_name)
    dataset_source = dataset_source_lib.Cifar10(
        FLAGS.batch_size // jax.host_count(),
        FLAGS.image_level_augmentations,
        FLAGS.batch_level_augmentations,
        image_size=image_size)
  elif FLAGS.dataset == 'cifar100':
    if FLAGS.from_pretrained_checkpoint:
      image_size = efficientnet.name_to_image_size(FLAGS.model_name)
    dataset_source = dataset_source_lib.Cifar100(
        FLAGS.batch_size // jax.host_count(),
        FLAGS.image_level_augmentations,
        FLAGS.batch_level_augmentations,
        image_size=image_size)
  elif FLAGS.dataset == 'fashion_mnist':
    dataset_source = dataset_source_lib.FashionMnist(
        FLAGS.batch_size, FLAGS.image_level_augmentations,
        FLAGS.batch_level_augmentations)
  elif FLAGS.dataset == 'svhn':
    dataset_source = dataset_source_lib.SVHN(
        FLAGS.batch_size, FLAGS.image_level_augmentations,
        FLAGS.batch_level_augmentations)
  elif FLAGS.dataset == 'imagenet':
    imagenet_image_size = efficientnet.name_to_image_size(FLAGS.model_name)
    dataset_source = dataset_source_imagenet.Imagenet(
        FLAGS.batch_size // jax.host_count(), imagenet_image_size,
        FLAGS.image_level_augmentations)
  else:
    raise ValueError('Dataset not recognized.')

  if 'cifar' in FLAGS.dataset or 'svhn' in FLAGS.dataset:
    if image_size is None or 'svhn' in FLAGS.dataset:
      image_size = 32
    num_channels = 3
    num_classes = 100 if FLAGS.dataset == 'cifar100' else 10
  elif FLAGS.dataset == 'fashion_mnist':
    image_size = 28  # For Fashion MNIST.
    num_channels = 1
    num_classes = 10
  elif FLAGS.dataset == 'imagenet':
    image_size = imagenet_image_size
    num_channels = 3
    num_classes = 1000
  else:
    raise ValueError('Dataset not recognized.')

  try:
    model, state = load_imagenet_model.get_model(FLAGS.model_name,
                                                 local_batch_size, image_size,
                                                 num_classes)
  except load_imagenet_model.ModelNameError:
    model, state = load_model.get_model(FLAGS.model_name, local_batch_size,
                                        image_size, num_classes, num_channels)

  # The learning rate will be overwritten by the lr schedule; we set it to 0.
  optimizer = flax_training.create_optimizer(model, 0.0)

  flax_training.train(optimizer, state, dataset_source, output_dir,
                      FLAGS.num_epochs)
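
# Sanity sketch (illustrative only, assuming the JAX version this script
# targets, where `jax.lax.tie_in` still exists): after the monkey patch above
# runs, `jax.random.bernoulli` draws hardware uniforms in [0, 1) and
# thresholds them, matching the shape and boolean semantics of the stock
# sampler without the threefry key-derivation cost, at the price of a
# different random stream.
key = jax.random.PRNGKey(0)
sample = jax.random.bernoulli(key, 0.3, shape=(4, 4))  # now hardware-backed
assert sample.shape == (4, 4)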
def main(config, output_dir):

  seed = config.get('seed', 0)
  rng = jax.random.PRNGKey(seed)
  tf.random.set_seed(seed)

  if config.get('dataset_dir'):
    logging.info('data_dir=%s', config.dataset_dir)
  logging.info('Output dir: %s', output_dir)

  save_checkpoint_path = None
  if config.get('checkpoint_steps'):
    gfile.makedirs(output_dir)
    save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz')

  # Create an asynchronous multi-metric writer.
  writer = metric_writers.create_default_writer(
      output_dir, just_logging=jax.process_index() > 0)

  # The pool is used to perform misc operations such as logging in async way.
  pool = multiprocessing.pool.ThreadPool()

  def write_note(note):
    if jax.process_index() == 0:
      logging.info('NOTE: %s', note)

  write_note('Initializing...')

  # Verify settings to make sure no checkpoints are accidentally missed.
  if config.get('keep_checkpoint_steps'):
    assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
    assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
        f'`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should be '
        f'divisible by `checkpoint_steps` ({config.checkpoint_steps}).')

  batch_size = config.batch_size
  batch_size_eval = config.get('batch_size_eval', batch_size)
  if (batch_size % jax.device_count() != 0 or
      batch_size_eval % jax.device_count() != 0):
    raise ValueError(f'Batch sizes ({batch_size} and {batch_size_eval}) must '
                     f'be divisible by device number ({jax.device_count()})')

  local_batch_size = batch_size // jax.process_count()
  local_batch_size_eval = batch_size_eval // jax.process_count()
  logging.info(
      'Global batch size %d on %d hosts results in %d local batch size. '
      'With %d devices per host (%d devices total), that\'s a %d per-device '
      'batch size.', batch_size, jax.process_count(), local_batch_size,
      jax.local_device_count(), jax.device_count(),
      local_batch_size // jax.local_device_count())

  write_note('Initializing train dataset...')
  rng, train_ds_rng = jax.random.split(rng)
  train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index())
  train_ds = input_utils.get_data(
      dataset=config.dataset,
      split=config.train_split,
      rng=train_ds_rng,
      process_batch_size=local_batch_size,
      preprocess_fn=preprocess_spec.parse(
          spec=config.pp_train, available_ops=preprocess_utils.all_ops()),
      shuffle_buffer_size=config.shuffle_buffer_size,
      prefetch_size=config.get('prefetch_to_host', 2),
      data_dir=config.get('data_dir'))

  # Start prefetching already.
  train_iter = input_utils.start_input_pipeline(
      train_ds, config.get('prefetch_to_device', 1))

  write_note('Initializing val dataset(s)...')

  def _get_val_split(dataset, split, pp_eval, data_dir=None):
    # We do ceil rounding such that we include the last incomplete batch.
    nval_img = input_utils.get_num_examples(
        dataset,
        split=split,
        process_batch_size=local_batch_size_eval,
        drop_remainder=False,
        data_dir=data_dir)
    val_steps = int(np.ceil(nval_img / batch_size_eval))
    logging.info('Running validation for %d steps for %s, %s', val_steps,
                 dataset, split)

    if isinstance(pp_eval, str):
      pp_eval = preprocess_spec.parse(
          spec=pp_eval, available_ops=preprocess_utils.all_ops())

    val_ds = input_utils.get_data(
        dataset=dataset,
        split=split,
        rng=None,
        process_batch_size=local_batch_size_eval,
        preprocess_fn=pp_eval,
        cache=config.get('val_cache', 'batched'),
        num_epochs=1,
        repeat_after_batching=True,
        shuffle=False,
        prefetch_size=config.get('prefetch_to_host', 2),
        drop_remainder=False,
        data_dir=data_dir)

    return val_ds

  val_ds_splits = {
      'val':
          _get_val_split(config.dataset, config.val_split, config.pp_eval,
                         config.get('data_dir'))
  }

  if config.get('test_split'):
    val_ds_splits.update({
        'test':
            _get_val_split(
                config.dataset,
                split=config.test_split,
                pp_eval=config.pp_eval,
                data_dir=config.get('data_dir'))
    })

  if config.get('eval_on_cifar_10h'):
    cifar10_to_cifar10h_fn = data_uncertainty_utils.create_cifar10_to_cifar10h_fn(
        config.get('data_dir', None))
    preprocess_fn = preprocess_spec.parse(
        spec=config.pp_eval_cifar_10h,
        available_ops=preprocess_utils.all_ops())
    pp_eval = lambda ex: preprocess_fn(cifar10_to_cifar10h_fn(ex))
    val_ds_splits['cifar_10h'] = _get_val_split(
        'cifar10',
        split=config.get('cifar_10h_split') or 'test',
        pp_eval=pp_eval,
        data_dir=config.get('data_dir'))
  elif config.get('eval_on_imagenet_real'):
    imagenet_to_real_fn = data_uncertainty_utils.create_imagenet_to_real_fn()
    preprocess_fn = preprocess_spec.parse(
        spec=config.pp_eval_imagenet_real,
        available_ops=preprocess_utils.all_ops())
    pp_eval = lambda ex: preprocess_fn(imagenet_to_real_fn(ex))
    val_ds_splits['imagenet_real'] = _get_val_split(
        'imagenet2012_real',
        split=config.get('imagenet_real_split') or 'validation',
        pp_eval=pp_eval,
        data_dir=config.get('data_dir'))

  ood_ds = None
  if config.get('ood_datasets') and config.get('ood_methods'):
    if config.get('ood_methods'):  # config.ood_methods is not an empty list.
      logging.info('loading OOD dataset = %s', config.get('ood_datasets'))
      ood_ds, ood_ds_names = ood_utils.load_ood_datasets(
          config.dataset,
          config.ood_datasets,
          config.ood_split,
          config.pp_eval,
          config.pp_eval_ood,
          config.ood_methods,
          config.train_split,
          config.get('data_dir'),
          _get_val_split,
      )

  ntrain_img = input_utils.get_num_examples(
      config.dataset,
      split=config.train_split,
      process_batch_size=local_batch_size,
      data_dir=config.get('data_dir'))
  steps_per_epoch = int(ntrain_img / batch_size)

  if config.get('num_epochs'):
    total_steps = int(config.num_epochs * steps_per_epoch)
    assert not config.get('total_steps'), 'Set either num_epochs or total_steps'
  else:
    total_steps = config.total_steps

  logging.info('Total train data points: %d', ntrain_img)
  logging.info(
      'Running for %d steps, that means %f epochs and %d steps per epoch',
      total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)

  write_note('Initializing model...')
  logging.info('config.model = %s', config.get('model'))
  model = ub.models.het_vision_transformer(
      num_classes=config.num_classes, **config.get('model', {}))

  # We want all parameters to be created in host RAM, not on any device; they
  # will be sent there later as needed. Otherwise we have already encountered
  # two situations where we allocated them twice.
  @partial(jax.jit, backend='cpu')
  def init(rng):
    image_size = tuple(train_ds.element_spec['image'].shape[2:])
    logging.info('image_size = %s', image_size)
    dummy_input = jnp.zeros((local_batch_size,) + image_size, jnp.float32)
    rng, diag_noise_rng, standard_noise_rng = jax.random.split(rng, num=3)
    init_rngs = {
        'params': rng,
        'diag_noise_samples': diag_noise_rng,
        'standard_norm_noise_samples': standard_noise_rng
    }
    params = flax.core.unfreeze(
        model.init(init_rngs, dummy_input, train=False))['params']

    head = ('multiclass_head' if config.get('model', {}).get('multiclass')
            else 'multilabel_head')
    # Set bias in the head to a low value, such that loss is small initially.
    if head in params:
      params[head]['loc_layer']['bias'] = jnp.full_like(
          params[head]['loc_layer']['bias'], config.get('init_head_bias', 0))

      # Initialize the head kernels to all zeros for fine-tuning.
      if config.get('model_init'):
        params[head]['loc_layer']['kernel'] = jnp.full_like(
            params[head]['loc_layer']['kernel'], 0)
        params[head]['diag_layer']['kernel'] = jnp.full_like(
            params[head]['diag_layer']['kernel'], 0)
        params[head]['diag_layer']['bias'] = jnp.full_like(
            params[head]['diag_layer']['bias'], 0)
        if 'scale_layer_homoscedastic' in params[head]:
          params[head]['scale_layer_homoscedastic']['kernel'] = jnp.full_like(
              params[head]['scale_layer_homoscedastic']['kernel'], 0)
          params[head]['scale_layer_homoscedastic']['bias'] = jnp.full_like(
              params[head]['scale_layer_homoscedastic']['bias'], 0)
        if 'scale_layer_heteroscedastic' in params[head]:
          params[head]['scale_layer_heteroscedastic']['kernel'] = jnp.full_like(
              params[head]['scale_layer_heteroscedastic']['kernel'], 0)
          params[head]['scale_layer_heteroscedastic']['bias'] = jnp.full_like(
              params[head]['scale_layer_heteroscedastic']['bias'], 0)

    return params

  (rng, rng_init, rng_dropout, diag_noise_rng,
   standard_noise_rng) = jax.random.split(rng, num=5)
  params_cpu = init(rng_init)

  if jax.process_index() == 0:
    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
    parameter_overview.log_parameter_overview(params_cpu)
    writer.write_scalars(step=0, scalars={'num_params': num_params})

  @partial(jax.pmap, axis_name='batch')
  def evaluation_fn(params, images, labels, mask):
    # Ignore the entries with all zero labels for evaluation.
    mask *= labels.max(axis=1)
    logits, out = model.apply(
        {'params': flax.core.freeze(params)},
        images,
        train=False,
        rngs={
            'dropout': rng_dropout,
            'diag_noise_samples': diag_noise_rng,
            'standard_norm_noise_samples': standard_noise_rng
        })
    label_indices = config.get('label_indices')
    if label_indices:
      logits = logits[:, label_indices]

    # Note that logits and labels are usually of the shape
    # [batch, num_classes]. But for OOD data, when
    # num_classes_ood > num_classes_ind, we need to adjust labels to
    # labels[:, :config.num_classes] to match the shape of logits. That is
    # just to avoid a shape mismatch; the output losses have no meaning for
    # OOD data, because OOD examples do not belong to any IND class.
    losses = getattr(train_utils, config.get('loss', 'sigmoid_xent'))(
        logits=logits,
        labels=labels[:, :(len(label_indices)
                           if label_indices else config.num_classes)],
        reduction=False)
    loss = jax.lax.psum(losses * mask, axis_name='batch')

    top1_idx = jnp.argmax(logits, axis=1)
    # Extracts the label at the highest logit index for each image.
    top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0]
    ncorrect = jax.lax.psum(top1_correct * mask, axis_name='batch')
    n = jax.lax.psum(mask, axis_name='batch')

    metric_args = jax.lax.all_gather(
        [logits, labels, out['pre_logits'], mask], axis_name='batch')
    return ncorrect, loss, n, metric_args

  @partial(jax.pmap, axis_name='batch')
  def cifar_10h_evaluation_fn(params, images, labels, mask):
    logits, out = model.apply(
        {'params': flax.core.freeze(params)},
        images,
        train=False,
        rngs={
            'dropout': rng_dropout,
            'diag_noise_samples': diag_noise_rng,
            'standard_norm_noise_samples': standard_noise_rng
        })
    label_indices = config.get('label_indices')
    if label_indices:
      logits = logits[:, label_indices]

    losses = getattr(train_utils, config.get('loss', 'softmax_xent'))(
        logits=logits, labels=labels, reduction=False)
    loss = jax.lax.psum(losses, axis_name='batch')

    top1_idx = jnp.argmax(logits, axis=1)
    # Extracts the label at the highest logit index for each image.
    one_hot_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)]

    top1_correct = jnp.take_along_axis(
        one_hot_labels, top1_idx[:, None], axis=1)[:, 0]
    ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
    n = jax.lax.psum(one_hot_labels, axis_name='batch')

    metric_args = jax.lax.all_gather(
        [logits, labels, out['pre_logits'], mask], axis_name='batch')
    return ncorrect, loss, n, metric_args

  # Setup function for computing representation.
  @partial(jax.pmap, axis_name='batch')
  def representation_fn(params, images, labels, mask):
    _, outputs = model.apply(
        {'params': flax.core.freeze(params)},
        images,
        train=False,
        rngs={
            'dropout': rng_dropout,
            'diag_noise_samples': diag_noise_rng,
            'standard_norm_noise_samples': standard_noise_rng
        })
    representation = outputs[config.fewshot.representation_layer]
    representation = jax.lax.all_gather(representation, 'batch')
    labels = jax.lax.all_gather(labels, 'batch')
    mask = jax.lax.all_gather(mask, 'batch')
    return representation, labels, mask

  # Load the optimizer from flax.
  opt_name = config.get('optim_name')
  write_note(f'Initializing {opt_name} optimizer...')
  opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))

  # We jit this, such that the arrays that are created are created on the same
  # device as the input is, in this case the CPU. Else they'd be on device[0].
  opt_cpu = jax.jit(opt_def.create)(params_cpu)

  weight_decay_rules = config.get('weight_decay', []) or []
  rescale_value = config.lr.base if config.get('weight_decay_decouple') else 1.
  weight_decay_fn = train_utils.get_weight_decay_fn(
      weight_decay_rules=weight_decay_rules, rescale_value=rescale_value)

  @partial(jax.pmap, axis_name='batch', donate_argnums=(0,))
  def update_fn(opt, lr, images, labels, rng):
    """Update step."""

    measurements = {}

    # Get device-specific loss rng.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model,
                                         jax.lax.axis_index('batch'))
    rng_model_local, diag_noise_rng, standard_noise_rng = jax.random.split(
        rng_model_local, num=3)

    def loss_fn(params, images, labels):
      logits, _ = model.apply(
          {'params': flax.core.freeze(params)},
          images,
          train=True,
          rngs={
              'dropout': rng_model_local,
              'diag_noise_samples': diag_noise_rng,
              'standard_norm_noise_samples': standard_noise_rng
          })
      label_indices = config.get('label_indices')
      if label_indices:
        logits = logits[:, label_indices]
      return getattr(train_utils, config.get('loss', 'sigmoid_xent'))(
          logits=logits, labels=labels)

    # Implementation considerations compared and summarized at
    # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
    l, g = train_utils.accumulate_gradient(
        jax.value_and_grad(loss_fn), opt.target, images, labels,
        config.get('grad_accum_steps'))
    l, g = jax.lax.pmean((l, g), axis_name='batch')

    # Log the gradient norm only if we need to compute it anyways (clipping)
    # or if we don't use grad_accum_steps, as they interact badly.
    if config.get('grad_accum_steps', 1) == 1 or config.get('grad_clip_norm'):
      grads, _ = jax.tree_flatten(g)
      l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
      measurements['l2_grads'] = l2_g

    # Optionally resize the global gradient to a maximum norm. We found this
    # useful in some cases across optimizers, hence it's in the main loop.
    if config.get('grad_clip_norm'):
      g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
      g = jax.tree_map(lambda p: g_factor * p, g)
    opt = opt.apply_gradient(g, learning_rate=lr)

    opt = opt.replace(target=weight_decay_fn(opt.target, lr))

    params, _ = jax.tree_flatten(opt.target)
    measurements['l2_params'] = jnp.sqrt(
        sum([jnp.vdot(p, p) for p in params]))

    return opt, l, rng, measurements

  if config.get('model.multiclass', False):
    default_reinit_params = []
  else:
    default_reinit_params = [
        'head/scale_layer_homoscedastic/kernel',
        'head/scale_layer_homoscedastic/bias',
        'head/scale_layer_heteroscedastic/kernel',
        'head/scale_layer_heteroscedastic/bias', 'head/loc_layer/kernel',
        'head/diag_layer/kernel', 'head/loc_layer/bias',
        'head/diag_layer/bias'
    ]
  default_reinit_params = (
      default_reinit_params +
      list(map(lambda k: 'multilabel_' + k, default_reinit_params)) +
      list(map(lambda k: 'multiclass_' + k, default_reinit_params)))

  rng, train_loop_rngs = jax.random.split(rng)
  if config.get('only_eval', False) or not config.get('reint_head', True):
    default_reinit_params = []
  checkpoint_data = checkpoint_utils.maybe_load_checkpoint(
      train_loop_rngs=train_loop_rngs,
      save_checkpoint_path=save_checkpoint_path,
      init_optimizer=opt_cpu,
      init_params=params_cpu,
      init_fixed_model_states=None,
      default_reinit_params=default_reinit_params,
      config=config,
  )
  train_loop_rngs = checkpoint_data.train_loop_rngs
  opt_cpu = checkpoint_data.optimizer
  accumulated_train_time = checkpoint_data.accumulated_train_time

  write_note('Adapting the checkpoint model...')
  adapted_params = checkpoint_utils.adapt_upstream_architecture(
      init_params=params_cpu, loaded_params=opt_cpu.target)
  opt_cpu = opt_cpu.replace(target=adapted_params)

  write_note('Kicking off misc stuff...')
  first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
  if first_step == 0 and jax.process_index() == 0:
    writer.write_hparams(dict(config))
  chrono = train_utils.Chrono(first_step, total_steps, batch_size,
                              accumulated_train_time)
  # Note: switch to ProfileAllHosts() if you need to profile all hosts.
  # (Xprof data become much larger and take longer to load for analysis.)
  profiler = periodic_actions.Profile(
      # Create profile after every restart to analyze pre-emption related
      # problems and assure we get similar performance in every run.
      logdir=output_dir,
      first_profile=first_step + 10)

  # Prepare the learning-rate and pre-fetch it to device to avoid delays.
  lr_fn = train_utils.create_learning_rate_schedule(total_steps,
                                                    **config.get('lr', {}))

  # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be
  # necessary for TPUs.
  lr_iter = train_utils.prefetch_scalar(
      map(lr_fn, range(total_steps)), config.get('prefetch_to_device', 1))

  write_note(f'Replicating...\n{chrono.note}')
  opt_repl = flax_utils.replicate(opt_cpu)

  write_note(f'Initializing few-shotters...\n{chrono.note}')
  fewshotter = None
  if 'fewshot' in config and fewshot is not None:
    fewshotter = fewshot.FewShotEvaluator(
        representation_fn, config.fewshot,
        config.fewshot.get('batch_size') or batch_size_eval)

  checkpoint_writer = None

  # Note: we return the train loss, val loss, and fewshot best l2s for use in
  # reproducibility unit tests.
  train_loss = -jnp.inf
  val_loss = {val_name: -jnp.inf for val_name, _ in val_ds_splits.items()}
  fewshot_results = {'dummy': {(0, 1): -jnp.inf}}

  write_note(f'First step compilations...\n{chrono.note}')
  logging.info('first_step = %s', first_step)
  # Advance the iterators if we are restarting from an earlier checkpoint.
  # TODO(dusenberrymw): Look into checkpointing dataset state instead.
  if first_step > 0:
    write_note('Advancing iterators after resuming from a checkpoint...')
    lr_iter = itertools.islice(lr_iter, first_step, None)
    train_iter = itertools.islice(train_iter, first_step, None)

  # Using a python integer for step here, because opt.state.step is allocated
  # on TPU during replication.
  for step, train_batch, lr_repl in zip(
      range(first_step + 1, total_steps + 1), train_iter, lr_iter):

    with jax.profiler.TraceAnnotation('train_step', step_num=step, _r=1):
      if not config.get('only_eval', False):
        opt_repl, loss_value, train_loop_rngs, extra_measurements = update_fn(
            opt_repl,
            lr_repl,
            train_batch['image'],
            train_batch['labels'],
            rng=train_loop_rngs)

    if jax.process_index() == 0:
      profiler(step)

    # Checkpoint saving.
    if not config.get('only_eval', False) and train_utils.itstime(
        step, config.get('checkpoint_steps'), total_steps, process=0):
      write_note('Checkpointing...')
      chrono.pause()
      train_utils.checkpointing_timeout(checkpoint_writer,
                                        config.get('checkpoint_timeout', 1))
      accumulated_train_time = chrono.accum_train_time
      # We need to transfer the weights over now, or else we risk keeping them
      # alive while they'll be updated in a future step, creating hard-to-debug
      # memory errors (see b/160593526). Also, this takes device 0's params
      # only.
      opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)

      # Check whether we want to keep a copy of the current checkpoint.
      copy_step = None
      if train_utils.itstime(step, config.get('keep_checkpoint_steps'),
                             total_steps):
        write_note('Keeping a checkpoint copy...')
        copy_step = step

      # Checkpoint should be a nested dictionary or FLAX dataclasses from
      # `flax.struct`. Both can be present in a checkpoint.
      checkpoint_data = checkpoint_utils.CheckpointData(
          optimizer=opt_cpu,
          train_loop_rngs=train_loop_rngs,
          accumulated_train_time=accumulated_train_time)
      checkpoint_writer = pool.apply_async(
          checkpoint_utils.checkpoint_trained_model,
          (checkpoint_data, save_checkpoint_path, copy_step))
      chrono.resume()

    # Report training progress.
    if not config.get('only_eval', False) and train_utils.itstime(
        step, config.log_training_steps, total_steps, process=0):
      write_note('Reporting training progress...')
      train_loss = loss_value[0]  # Keep to return for reproducibility tests.
      timing_measurements, note = chrono.tick(step)
      write_note(note)
      train_measurements = {}
      train_measurements.update({
          'learning_rate': lr_repl[0],
          'training_loss': train_loss,
      })
      train_measurements.update(
          flax.jax_utils.unreplicate(extra_measurements))
      train_measurements.update(timing_measurements)
      writer.write_scalars(step, train_measurements)

    # Report validation performance.
    if train_utils.itstime(step, config.log_eval_steps, total_steps):
      write_note('Evaluating on the validation set...')
      chrono.pause()
      for val_name, val_ds in val_ds_splits.items():
        # Sets up evaluation metrics.
        ece_num_bins = config.get('ece_num_bins', 15)
        auc_num_bins = config.get('auc_num_bins', 1000)
        ece = rm.metrics.ExpectedCalibrationError(num_bins=ece_num_bins)
        calib_auc = rm.metrics.CalibrationAUC(
            correct_pred_as_pos_label=False)
        oc_auc_0_5 = rm.metrics.OracleCollaborativeAUC(
            oracle_fraction=0.005, num_bins=auc_num_bins)
        oc_auc_1 = rm.metrics.OracleCollaborativeAUC(
            oracle_fraction=0.01, num_bins=auc_num_bins)
        oc_auc_2 = rm.metrics.OracleCollaborativeAUC(
            oracle_fraction=0.02, num_bins=auc_num_bins)
        oc_auc_5 = rm.metrics.OracleCollaborativeAUC(
            oracle_fraction=0.05, num_bins=auc_num_bins)
        label_diversity = tf.keras.metrics.Mean()
        sample_diversity = tf.keras.metrics.Mean()
        ged = tf.keras.metrics.Mean()

        # Runs evaluation loop.
        val_iter = input_utils.start_input_pipeline(
            val_ds, config.get('prefetch_to_device', 1))
        ncorrect, loss, nseen = 0, 0, 0
        for batch in val_iter:
          if val_name == 'cifar_10h':
            batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                cifar_10h_evaluation_fn(opt_repl.target, batch['image'],
                                        batch['labels'], batch['mask']))
          else:
            batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                evaluation_fn(opt_repl.target, batch['image'],
                              batch['labels'], batch['mask']))
          # All results are a replicated array shaped as follows:
          # (local_devices, per_device_batch_size, elem_shape...)
          # with each local device's entry being identical as they got psum'd.
          # So let's just take the first one to the host as numpy.
          ncorrect += np.sum(np.array(batch_ncorrect[0]))
          loss += np.sum(np.array(batch_losses[0]))
          nseen += np.sum(np.array(batch_n[0]))

          # Here we parse batch_metric_args to compute uncertainty metrics
          # (e.g., ECE or Calibration AUC).
          logits, labels, _, masks = batch_metric_args
          masks = np.array(masks[0], dtype=bool)
          logits = np.array(logits[0])
          probs = jax.nn.softmax(logits)
          # From one-hot to integer labels, as required by ECE.
          int_labels = np.argmax(np.array(labels[0]), axis=-1)
          int_preds = np.argmax(logits, axis=-1)
          confidence = np.max(probs, axis=-1)
          for p, c, l, d, m, label in zip(probs, confidence, int_labels,
                                          int_preds, masks, labels[0]):
            ece.add_batch(p[m, :], label=l[m])
            calib_auc.add_batch(d[m], label=l[m], confidence=c[m])
            # TODO(jereliu): Extend to support soft multi-class probabilities.
            oc_auc_0_5.add_batch(d[m], label=l[m], custom_binning_score=c[m])
            oc_auc_1.add_batch(d[m], label=l[m], custom_binning_score=c[m])
            oc_auc_2.add_batch(d[m], label=l[m], custom_binning_score=c[m])
            oc_auc_5.add_batch(d[m], label=l[m], custom_binning_score=c[m])

            if val_name == 'cifar_10h' or val_name == 'imagenet_real':
              batch_label_diversity, batch_sample_diversity, batch_ged = (
                  data_uncertainty_utils.generalized_energy_distance(
                      label[m], p[m, :], config.num_classes))
              label_diversity.update_state(batch_label_diversity)
              sample_diversity.update_state(batch_sample_diversity)
              ged.update_state(batch_ged)

        val_loss[val_name] = loss / nseen  # Keep for reproducibility tests.
        val_measurements = {
            f'{val_name}_prec@1': ncorrect / nseen,
            f'{val_name}_loss': val_loss[val_name],
            f'{val_name}_ece': ece.result()['ece'],
            f'{val_name}_calib_auc': calib_auc.result()['calibration_auc'],
            f'{val_name}_oc_auc_0.5%':
                oc_auc_0_5.result()['collaborative_auc'],
            f'{val_name}_oc_auc_1%': oc_auc_1.result()['collaborative_auc'],
            f'{val_name}_oc_auc_2%': oc_auc_2.result()['collaborative_auc'],
            f'{val_name}_oc_auc_5%': oc_auc_5.result()['collaborative_auc'],
        }
        writer.write_scalars(step, val_measurements)

        if val_name == 'cifar_10h' or val_name == 'imagenet_real':
          cifar_10h_measurements = {
              f'{val_name}_label_diversity': label_diversity.result(),
              f'{val_name}_sample_diversity': sample_diversity.result(),
              f'{val_name}_ged': ged.result(),
          }
          writer.write_scalars(step, cifar_10h_measurements)

      # OOD eval.
      # There are two entries in the ood_ds dict (in-dist, ood), and this
      # section computes metrics using both pieces. This is in contrast to
      # the normal validation eval above, where we eval metrics separately
      # for each val split in val_ds.
      if ood_ds and config.ood_methods:
        ood_measurements = ood_utils.eval_ood_metrics(
            ood_ds,
            ood_ds_names,
            config.ood_methods,
            evaluation_fn,
            opt_repl.target,
            n_prefetch=config.get('prefetch_to_device', 1))
        writer.write_scalars(step, ood_measurements)
      chrono.resume()

    if 'fewshot' in config and fewshotter is not None:
      # Compute few-shot on-the-fly evaluation.
      if train_utils.itstime(step, config.fewshot.log_steps, total_steps):
        chrono.pause()
        write_note(f'Few-shot evaluation...\n{chrono.note}')
        # Keep `results` to return for reproducibility tests.
        fewshot_results, best_l2 = fewshotter.run_all(
            opt_repl.target, config.fewshot.datasets)

        # TODO(dusenberrymw): Remove this once fewshot.py is updated.
        def make_writer_measure_fn(step):
          def writer_measure(name, value):
            writer.write_scalars(step, {name: value})
          return writer_measure

        fewshotter.walk_results(
            make_writer_measure_fn(step), fewshot_results, best_l2)
        chrono.resume()

    # End of step.
    if config.get('testing_failure_step'):
      # Break early to simulate infra failures in test cases.
      if config.testing_failure_step == step:
        break

  write_note(f'Done!\n{chrono.note}')
  pool.close()
  pool.join()
  writer.close()

  # Return final training loss, validation loss, and fewshot results for
  # reproducibility test cases.
  return train_loss, val_loss, fewshot_results
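
# A minimal sketch (hypothetical harness, not part of this file) of the kind
# of reproducibility test the return values above are meant to support; the
# expected value and tolerance here are placeholders:
def test_main_is_reproducible(config, output_dir, expected_train_loss):
  train_loss, _, _ = main(config, output_dir)
  np.testing.assert_allclose(train_loss, expected_train_loss, rtol=1e-4)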