def train_and_evaluate(config, work_dir, try_checkpoint=True): """Execute model training and evaluation loop. Args: config: Hyperparameter configuration for training and evaluation. work_dir: Directory where the tensorboard summaries are written to. try_checkpoint: Should try to load checkpoint (usually enabled, practical for debugging purposes to disable). Returns: The train state (which includes the `.params`). """ # Init rng key. msg = f'Running with seed {config.seed}.' logging.info(msg) rng = jax.random.PRNGKey(config.seed) data_rng, rng = jax.random.split(rng) is_first_host = jax.process_index() == 0 train_ds, test_ds, shape, num_classes = datasets.get_dataset( config, data_rng) # config.mask_shape = mask_shape config.data_shape = shape config.num_classes = num_classes writer = metric_writers.create_default_writer( work_dir, just_logging=jax.process_index() > 0) rng, init_rng = jax.random.split(rng) # Create output directory for saving samples. output_path = work_dir tf.io.gfile.makedirs(output_path) model, variables = model_setup(init_rng, config) # From now on we want different rng across hosts: rng = jax.random.fold_in(rng, jax.process_index()) tx = optax.adam(config.learning_rate, b1=0.9, b2=config.beta2, eps=1e-08, eps_root=0.0) state = custom_train_state.TrainState.create(params=variables['params'], tx=tx) if try_checkpoint: state, start_epoch = checkpoint.restore_from_path(work_dir, state) if start_epoch is None: start_epoch = 1 else: # For debugging we start at zero, so we immediately do detailed eval. start_epoch = 0 if is_first_host and start_epoch == 1: config_dict = dict(config) writer.write_hparams(config_dict) if is_first_host and start_epoch in (0, 1): # Dump config file to work dir for easy model loading. config_path = os.path.join(work_dir, 'config') with tf.io.gfile.GFile(config_path, 'wb') as fp: pickle.dump(config, fp) test_rng, train_rng = jax.random.split(rng) kl_tracker_train = util_fns.KLTracker(num_steps=model.num_steps) kl_history = [] p_train_step = jax.pmap(functools.partial(train_step, model=model, config=config), axis_name='batch', in_axes=(None, 0, 0), out_axes=(0, 0, None), donate_argnums=(2, )) # The only axes that are broadcasted are the in- and output rng key ones. The # rng is the first arg, and the last return value. p_eval_step = jax.pmap(functools.partial(eval_step, model=model), axis_name='batch', in_axes=(None, 0, 0), out_axes=(0, None)) # Replicate state. state = flax.jax_utils.replicate(state) with metric_writers.ensure_flushes(writer): for epoch in range(start_epoch, config.num_epochs + 1): # Train part. state, train_metrics, train_rng = train_epoch( p_train_step, state, train_ds, config.batch_size, epoch, train_rng, kl_tracker_train) # Val part. eval_metrics, test_rng = eval_model(p_eval_step, test_rng, state, test_ds, epoch) # Metric logging. if is_first_host: log_standard_metrics(writer, train_metrics, eval_metrics, epoch) kl_values = kl_tracker_train.get_kl_per_t() kl_history.append(np.array(kl_values)) # Prune to avoid too much memory consumption. kl_history = kl_history[-50:] if epoch == 15 or epoch % config.detailed_eval_every == 0: if is_first_host: loss_components_path = os.path.join( work_dir, 'loss_components') with tf.io.gfile.GFile(loss_components_path, 'wb') as fp: pickle.dump(kl_history[-1], fp) test_rng = extensive_eval(config, test_rng, writer, output_path, model, state, kl_history, test_ds, epoch) # Save to checkpoint. 
      if is_first_host and epoch % config.save_every == 0:
        # Save to epoch + 1 since current epoch has just been completed.
        logging.info('saving checkpoint')
        checkpoint.save_checkpoint(
            work_dir,
            state=flax.jax_utils.unreplicate(state),
            step=epoch + 1,
            keep=2)
        logging.info('finished saving checkpoint')

  return state
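# --- Hedged sketch (not part of the original training code) of the pmap
# rng-threading pattern noted above: the key is broadcast to every device
# (in_axes=None), each device decorrelates it by folding in its axis index,
# and a single fresh key is returned unstacked (out_axes=None). All names
# below are illustrative only.
import jax
import jax.numpy as jnp


def _rng_threading_sketch(rng, x):
  device_rng = jax.random.fold_in(rng, jax.lax.axis_index('batch'))
  new_rng = jax.random.split(rng, 2)[0]  # Identical on every device.
  return x + jax.random.normal(device_rng, x.shape), new_rng


p_rng_threading_sketch = jax.pmap(
    _rng_threading_sketch, axis_name='batch',
    in_axes=(None, 0), out_axes=(0, None))
# Usage: out, rng = p_rng_threading_sketch(
#     jax.random.PRNGKey(0), jnp.ones((jax.local_device_count(), 4)))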
def generate(config: ml_collections.ConfigDict): """Generates memories.""" # Establish host information local_device_count = jax.local_device_count() device_count = jax.device_count() process_count = jax.process_count() process_index = jax.process_index() task = memory_generation_task.MemoryGenerationTask model_config = ml_collections.FrozenConfigDict(config.model_config) model = task.build_model(model_config) p_predict_step = jax.pmap(functools.partial( task.make_prediction_fn(config), model_config, ), axis_name='batch') rng = jax.random.PRNGKey(config.seed) # Initialization needs to be pmapped because models use collective ops. # Create dummy input dummy_input = { key: jnp.tile(value, (local_device_count, ) + (1, ) * value.ndim) for key, value in task.dummy_input(config).items() } rng, init_rng = jax.random.split(rng) init_rng = jax.random.split(init_rng, local_device_count) logging.info('Initializing model.') initial_variables = jax.pmap(model.init, 'batch', static_broadcasted_argnums=2)(init_rng, dummy_input, True) logging.info('Finished initializing model.') initial_variables = initial_variables.unfreeze() if config.load_weights is not None: logging.info('Loading model weights from file') loaded_variables = task.load_weights(config) unexpected, missing = checkpoint_utils.merge_nested_dicts( initial_variables, loaded_variables) logging.info('*** Unexpected features: ***') for feature_name in unexpected: logging.info('\t%s', feature_name) # In the prediction mode we don't allow any features to be missing # pylint: disable=g-explicit-length-test if len(missing) > 0: raise ValueError('Missing features: %s' % ','.join(missing)) # model_params = jax_utils.unreplicate(initial_variables['params']) model_params = initial_variables['params'] model_vars = { key: value for key, value in initial_variables.items() if key != 'params' } # We access model params only from train state. 
del initial_variables writer = metric_writers.create_default_writer( config.output_dir, just_logging=process_index > 0) max_length = config.get('max_length_with_entity_tokens', model_config.encoder_config.max_length) num_total_memories = math.ceil(config.num_total_memories / process_count) memory_saver = memory_generation_task.MemorySaver( num_total_memories=num_total_memories, memory_dim=config.memory_dim, max_length=max_length, max_mentions_per_sample=config.max_mentions_per_sample, memory_key_dim=config.get('memory_key_dim')) n_samples = 0 data_iter = get_data_iterator(config) logging.info('Start memory generation.') with metric_writers.ensure_flushes(writer): for step, batch in enumerate(data_iter): batch = jax.tree_map(jnp.asarray, batch) predictions = p_predict_step( model_params, model_vars, batch, ) predictions = jax.device_get(predictions) memory_saver.add_memories(batch, predictions) n_devices, batch_size, _ = batch['text_ids'].shape logging.log_first_n( logging.INFO, 'Process %d / %d: ' 'Finished generating step %d, local devices %d, batch size %d', 5, process_index, process_count, step, n_devices, batch_size) n_samples += device_count * config.per_device_batch_size if (step % config.log_every_steps == 0 or memory_saver.get_num_memories() >= num_total_memories): writer.write_scalars( step, dict(n_memories=memory_saver.get_num_memories(), n_samples=n_samples)) if memory_saver.get_num_memories() >= num_total_memories: break logging.info('Process %d / %d: Finished generating memories: %d out of %d', process_index, process_count, memory_saver.get_num_memories(), num_total_memories) start_time = time.time() logging.info('Process %d / %d: Start saving generated memories to files.', process_index, process_count) memory_saver.save(config.output_dir, num_shards=config.num_shards, stride=process_count, offset=process_index, shard_size_divisible=config.shard_size_divisible) logging.info( 'Process %d / %d: Finished saving generated memories to files in %.2f seconds', process_index, process_count, time.time() - start_time)
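# --- Hedged illustration (not from the original pipeline) of the dummy-input
# tiling used when pmapping model.init above: each per-example feature gets a
# new leading axis of size local_device_count. `feature` is a made-up array.
import jax
import jax.numpy as jnp

feature = jnp.zeros((4, 128), jnp.int32)  # [batch, seq_len]
tiled = jnp.tile(feature, (jax.local_device_count(),) + (1,) * feature.ndim)
# tiled.shape == (jax.local_device_count(), 4, 128)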
def train_and_evaluate(self, workdir): """Runs a training and evaluation loop. Args: workdir: Working directory for checkpoints and TF summaries. If this contains checkpoint training will be resumed from the latest checkpoint. """ tf.io.gfile.makedirs(workdir) config = self.config substeps = config.training.substeps # Learning rate schedule. num_train_steps = config.training.num_train_steps logging.info('num_train_steps=%d', num_train_steps) # Get train state state = self._train_state # Set up checkpointing of the model and the input pipeline. checkpoint_dir = os.path.join(workdir, 'checkpoints') ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, max_to_keep=5) state = ckpt.restore_or_initialize(state) initial_step = int(state.step) # Distribute training. state = flax_utils.replicate(state) writer = metric_writers.create_default_writer( workdir, just_logging=jax.process_index() > 0) if initial_step == 0: writer.write_hparams(dict(config)) logging.info('Starting training loop at step %d.', initial_step) hooks = [] report_progress = periodic_actions.ReportProgress( num_train_steps=num_train_steps, writer=writer) if jax.process_index() == 0: hooks += [ report_progress, periodic_actions.Profile(num_profile_steps=5, logdir=workdir) ] step = initial_step with metric_writers.ensure_flushes(writer): while step < num_train_steps: # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU # devices. is_last_step = step + substeps >= num_train_steps with jax.profiler.StepTraceAnnotation('train', step_num=step): inputs = jax.tree_map(np.asarray, next(self._train_iter)) state, outputs = self._update_func(state, inputs) # Quick indication that training is happening. logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step) for h in hooks: h(step) new_step = int(state.step[0]) assert new_step == step + substeps step = new_step is_eval = step % config.logs.eval_full_every_steps == 0 or is_last_step if step % config.logs.log_loss_every_steps == 0 and not is_eval: def avg_over_substeps(x): assert x.shape[0] == substeps return float(x.mean(axis=0)) # Extract scalars and images. outputs = flax_utils.unreplicate(outputs) outputs = jax.tree_map(avg_over_substeps, outputs) scalars = outputs['scalars'] writer.write_scalars(step, scalars) if is_eval: with report_progress.timed('eval_full'): outputs = self._eval_epoch(params=state.ema_params) outputs = flax_utils.unreplicate(outputs) scalars = outputs['scalars'] writer.write_scalars(step, scalars) if step % config.logs.checkpoint_every_steps == 0 or is_last_step: with report_progress.timed('checkpoint'): ckpt.save(flax_utils.unreplicate(state)) logging.info('Finishing training at step %d', num_train_steps)
def train_and_evaluate(config, workdir): """Runs a training and evaluation loop. Args: config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains checkpoint training will be resumed from the latest checkpoint. """ is_first_process = jax.process_index() == 0 tf.io.gfile.makedirs(workdir) # Load Dataset # --------------------------------------------------------------------------- logging.info('Initializing dataset.') train_ds, eval_ds, test_ds, encoder = input_pipeline.get_datasets( config) config.seq_length = 250 vocab_size = int(encoder.vocab_size()) config.num_classes = vocab_size config.data_shape = (config.seq_length, 1) logging.info('Training with vocab size %d', vocab_size) def decode_tokens(toks): return encoder.detokenize(toks) start_step = 0 rng = jax.random.PRNGKey(config.seed) rng, init_rng = jax.random.split(rng) config.per_device_batch_size = config.batch_size // jax.process_count() logging.info('Initializing model, optimizer, and step functions.') # Build Model and Optimizer # --------------------------------------------------------------------------- model, initial_variables = model_setup(init_rng, config) # Instead of passing the optimizer fns directly, we use a fn that returns # the optimizer given a learning rate. def tx_fn(lr): return optax.adamw( lr, b1=0.9, b2=0.99, eps=1e-08, eps_root=0.0, weight_decay=config.weight_decay) state = language_train_state.TrainState.create( params=initial_variables['params'], tx_fn=tx_fn) # We access model params only from state below via state.params. del initial_variables if config.restore_checkpoints: # Restore unreplicated model state from last checkpoint. state = checkpoints.restore_checkpoint(workdir, state) # Grab last step. start_step = int(state.step) writer = metric_writers.create_default_writer( workdir, just_logging=not is_first_process) if start_step == 0: config_dict = dict(config) writer.write_hparams(config_dict) if is_first_process and start_step == 0: # Dump config file to work dir for easy model loading. config_path = os.path.join(workdir, 'config') with tf.io.gfile.GFile(config_path, 'wb') as fp: pickle.dump(config, fp) print('Using state', type(state)) # Replicate state. state = jax_utils.replicate(state) learning_rate_fn = create_learning_rate_scheduler( factors=config.lr_factors, base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps) # Compile multidevice versions of train/eval/predict step fn. p_train_step = jax.pmap( functools.partial( train_step, model=model, learning_rate_fn=learning_rate_fn, clip_grad=config.clip_grad, ema_momentum=config.get('ema_momentum', 0.999)), axis_name='batch', in_axes=(0, 0), donate_argnums=(0,)) p_eval_step = jax.pmap( functools.partial( eval_step, model=model), axis_name='batch') # Main Train Loop # --------------------------------------------------------------------------- # We init the first set of train PRNG keys, but update it afterwards inside # the main pmap'd training update for performance. 
rng = jax.random.fold_in(rng, jax.process_index()) rng1, rng2, rng3, extensive_eval_rngs, sample_rng = jax.random.split(rng, 5) train_rngs = jax.random.split(rng1, jax.local_device_count()) eval_rngs = jax.random.split(rng2, jax.local_device_count()) test_rngs = jax.random.split(rng3, jax.local_device_count()) del rng, rng1, rng2, rng3 logging.info('Starting training loop.') hooks = [] report_progress = periodic_actions.ReportProgress( num_train_steps=config.num_train_steps, writer=writer) if is_first_process: hooks += [ report_progress, periodic_actions.Profile(logdir=workdir, num_profile_steps=5) ] train_metrics = [] # Iterator that does epoch-wise indefinite iteration. def iterate_train(train_ds): epoch = 1 while True: msg = f'Starting epoch {epoch}' logging.info(msg) for batch in train_ds: yield batch epoch += 1 train_iter = iterate_train(train_ds) kl_tracker_train = util_fns.KLTracker(num_steps=model.num_steps) kl_history = [] with metric_writers.ensure_flushes(writer): step = start_step for step in range(start_step, config.num_train_steps): is_last_step = step == config.num_train_steps - 1 # Shard data to devices and do a training step. with jax.profiler.StepTraceAnnotation('train', step_num=step): batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter))) state, metrics = p_train_step( state, batch, rng=train_rngs) train_metrics.append(metrics) # Quick indication that training is happening. logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step) for h in hooks: h(step) # Periodic metric handling. if step > 0 and (step % config.eval_every_steps == 0 or is_last_step): with report_progress.timed('training_metrics'): logging.info('Gathering training metrics.') train_metrics = common_utils.get_metrics(train_metrics) # First handle loss terms per step. t_batch = train_metrics.pop('t_batch') nelbo_per_t_batch = train_metrics.pop('nelbo_per_t_batch') kl_tracker_train.update( t_batch.reshape(-1), nelbo_per_t_batch.reshape(-1)) kl_values = kl_tracker_train.get_kl_per_t() kl_history.append(np.array(kl_values)) kl_history = kl_history[-100:] # Keep last 100 items only. # Handle remaining `standard` metrics summary = jax.tree_map(jnp.mean, train_metrics) summary = {'train_' + k: v for k, v in summary.items()} writer.write_scalars(step, summary) train_metrics = [] with report_progress.timed('eval'): eval_results, eval_rngs = evaluate( p_eval_step=p_eval_step, params=state.ema_params, eval_ds=eval_ds, rng=eval_rngs) writer.write_scalars( step, {'eval_' + k: v for k, v in eval_results.items()}) test_results, test_rngs = evaluate( p_eval_step=p_eval_step, params=state.ema_params, eval_ds=test_ds, rng=test_rngs) writer.write_scalars( step, {'test_' + k: v for k, v in test_results.items()}) if step == 1000 or (step > 0 and step % config.detailed_eval_every_steps == 0): if is_first_process: loss_components_path = os.path.join(workdir, 'loss_components') with tf.io.gfile.GFile(loss_components_path, 'wb') as fp: pickle.dump(kl_history[-1], fp) extensive_eval_rngs = extensive_eval( config, extensive_eval_rngs, writer, workdir, model, state, kl_history, test_ds, step, decode_tokens) with report_progress.timed('generate_text'): generate_prediction(sample_rng, config, model, state, writer, decode_tokens, step) # Save a checkpoint on one host after every checkpoint_freq steps. 
      save_checkpoint = (
          step > 0 and
          (step % config.checkpoint_every_steps == 0 or is_last_step))
      if config.save_checkpoints and save_checkpoint and is_first_process:
        with report_progress.timed('checkpoint'):
          checkpoints.save_checkpoint(
              workdir, jax_utils.unreplicate(state), step, overwrite=True)
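# --- Minimal re-implementation sketch of flax's common_utils.shard used in
# the training loop above: the leading host-batch axis is reshaped into
# [local_device_count, per_device_batch] so the array can be fed to pmap.
# This mirrors the library helper; see flax.training.common_utils for the
# canonical version.
import jax


def _shard_sketch(xs):
  n = jax.local_device_count()
  return jax.tree_util.tree_map(
      lambda x: x.reshape((n, -1) + x.shape[1:]), xs)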
def train_and_evaluate(config, workdir): """Runs a training and evaluation loop. Args: config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains checkpoint training will be resumed from the latest checkpoint. """ logging.info('Starting training at %s', workdir) tf.io.gfile.makedirs(workdir) if jax.process_index() == 0: with tf.io.gfile.GFile(os.path.join(workdir, 'config.json'), 'w') as f: json.dump(config.to_dict(), f, indent=2) rng = jax.random.PRNGKey(config.seed) # Build input pipeline. rng, data_rng = jax.random.split(rng) data_rng = jax.random.fold_in(data_rng, jax.process_index()) train_ds, eval_ds = input_pipeline.create_datasets(config.dataset, data_rng) train_iter = iter(train_ds) test_ds = [] for split in config.dataset.test_splits: ds = input_pipeline.create_val_dataset( config.dataset, split, config.dataset.test_per_device_batch_size, config.dataset.test_pad_last_batch) test_ds.append(ds) # Learning rate schedule. num_train_steps = config.num_train_steps if num_train_steps == -1: num_train_steps = train_ds.cardinality().numpy() steps_per_epoch = num_train_steps // config.dataset.num_epochs logging.info('num_train_steps=%d, steps_per_epoch=%d', num_train_steps, steps_per_epoch) learning_rate_fn = functools.partial( train_utils.get_learning_rate, base_learning_rate=config.learning_rate, num_train_steps=num_train_steps, schedule_type=config.learning_rate_schedule, warmup_proportion=config.warmup_proportion, step_boundaries=config.learning_rate_step_boundaries) # Initialize model. inputs = train_utils.get_init_inputs(train_ds) rng, model_rng = jax.random.split(rng) eval_config = models.TransformerConfig(**config.model.to_dict()) train_config = eval_config.replace(deterministic=False) model = models.Model(eval_config) state = train_utils.create_train_state(model, config, model_rng, inputs=inputs) # Set up checkpointing of the model and the input pipeline. checkpoint_dir = os.path.join(workdir, 'checkpoints') ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, max_to_keep=3) state = ckpt.restore_or_initialize(state) initial_step = int(state.step) + 1 # Distribute training. 
state = flax_utils.replicate(state) p_train_step = jax.pmap(functools.partial( train_step, config=train_config, learning_rate_fn=learning_rate_fn, grad_clip=config.grad_clip), axis_name='batch', donate_argnums=(0, )) p_eval_step = jax.pmap(functools.partial(eval_step, config=eval_config), axis_name='batch') writer = metric_writers.create_default_writer( workdir, just_logging=jax.process_index() > 0) if initial_step == 1: writer.write_hparams(train_utils.flatten_config(config)) logging.info('Starting training loop at step %d.', initial_step) hooks = [] report_progress = periodic_actions.ReportProgress( num_train_steps=num_train_steps, writer=writer) if jax.process_index() == 0: hooks += [ report_progress, periodic_actions.Profile( num_profile_steps=config.num_profile_steps, logdir=workdir) ] rng, train_rngs = jax.random.split(rng) train_rngs = jax.random.fold_in(train_rngs, jax.process_index()) train_rngs = jax.random.split(train_rngs, jax.local_device_count()) train_metrics = [] with metric_writers.ensure_flushes(writer): for step in range(initial_step, num_train_steps + 1): is_last_step = step == num_train_steps with jax.profiler.StepTraceContext('train', step_num=step): batch = jax.tree_map(np.asarray, next(train_iter)) state, metrics = p_train_step(batch=batch, rng=train_rngs, state=state) train_metrics.append(metrics) # Quick indication that training is happening. logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step) for h in hooks: h(step) if config.log_loss_every_steps > 0 and ( step % config.log_loss_every_steps == 0 or is_last_step): train_metrics = common_utils.get_metrics(train_metrics) lr = train_metrics.pop('learning_rate').mean() train_summary = train_utils.metrics_summary( train_metrics, 'train') train_summary['learning_rate'] = lr writer.write_scalars(step, train_summary) train_metrics = [] if config.eval_every_steps > 0 and (step % config.eval_every_steps == 0 or is_last_step): with report_progress.timed('eval'): eval_summary = evaluate(p_eval_step, state, eval_ds, config.num_eval_steps) writer.write_scalars(step, eval_summary) if config.checkpoint_every_steps > 0 and ( step % config.checkpoint_every_steps == 0 or is_last_step): with report_progress.timed('checkpoint'): ckpt.save(flax_utils.unreplicate(state)) logging.info('Checkpoint saved to %s', checkpoint_dir) logging.info('Finishing training at step %d', num_train_steps)
def train_and_evaluate(config, work_dir, try_checkpoint=True): """Execute model training and evaluation loop. Args: config: Hyperparameter configuration for training and evaluation. work_dir: Directory where the tensorboard summaries are written to. try_checkpoint: Should try to load checkpoint (usually enabled, practical for debugging purposes to disable). Returns: The train state (which includes the `.params`). """ # Init rng key. rng = jax.random.PRNGKey(config.seed) data_rng, rng = jax.random.split(rng) is_first_host = jax.process_index() == 0 if config.dataset.name.endswith('speech_commands09'): ds, ds_metadata = input_pipeline_sc09.get_dataset(data_rng, config) else: raise ValueError(f'Unknown dataset {config.dataset.name}.') # Immediately create infinite iterators. it = jax.tree_map(util_fns.get_iterator, ds) # TODO(agritsenko): Can we fix the ugly nested dicts? config.data_shape = ds_metadata['train']['shape']['inputs'][2:] config.num_classes = ds_metadata['train']['num_classes'] config.sample_rate = ds_metadata['train']['sample_rate'] writer = metric_writers.create_default_writer( work_dir, just_logging=jax.process_index() > 0) rng, init_rng = jax.random.split(rng) model, variables = model_setup(init_rng, config) # From now on we want different rng across hosts: rng = jax.random.fold_in(rng, jax.process_index()) def tx_fn(lr): return optax.adamw(lr, b1=0.9, b2=config.beta2, eps=1e-08, eps_root=0.0, weight_decay=config.weight_decay) state = language_train_state.TrainState.create(params=variables['params'], tx_fn=tx_fn) start_step = None if try_checkpoint: state, start_step = checkpoint.restore_from_path(work_dir, state) start_step = start_step or 0 # Use different rngs for train & eval. rng_train, rng_eval, rng_sample = jax.random.split(rng, 3) kl_tracker = util_fns.KLTracker(num_steps=model.num_steps) kl_history = [] learning_rate_fn = train_utils.create_learning_rate_scheduler( **config.learning_rate) p_train_step = jax.pmap(functools.partial( train_step, config=config, learning_rate_fn=learning_rate_fn, model=model), axis_name='batch', in_axes=(None, 0, 0), out_axes=(0, 0, None), donate_argnums=(2, )) # The only axes that are broadcasted are the in- and output rng key ones. The # rng is the first arg, and the last return value. p_eval_step = jax.pmap(functools.partial(eval_step, model=model), axis_name='batch', in_axes=(None, 0, 0), out_axes=(0, 0, None)) # Training length. logging.info('Training will start from step %d', start_step) # Replicate state. state = flax.jax_utils.replicate(state) # Setup hooks. hooks = [] report_progress = periodic_actions.ReportProgress( num_train_steps=config.num_train_steps, writer=writer) if is_first_host: hooks += [ report_progress, periodic_actions.Profile(logdir=work_dir, num_profile_steps=5) ] with metric_writers.ensure_flushes(writer): batch_metrics = [] for step in range(start_step, config.num_train_steps): logging.log_first_n(logging.INFO, f'Train step: {step}', 5) with jax.profiler.StepTraceAnnotation('train', step_num=step): state, metrics, rng_train = p_train_step( rng_train, next(it['train']), state) batch_metrics.append(metrics) # Cycle though hooks. 
for h in hooks: h(step) is_last_step = step == config.num_train_steps - 1 if (step % config.log_every_steps == 0) or is_last_step: with report_progress.timed('training_metrics'): ################### Process batch metrics ############################ batch_metrics = jax.device_get( flax.jax_utils.unreplicate(batch_metrics)) if 't_batch' in metrics: # TODO(agritsenko): Factor out into a separate function. # This processes the loss per t, although two nested for-loops # (counting the one inside kl_tracker), it actually does not hurt # timing performance meaningfully. batch_t = [ metrics['t_batch'].reshape(-1) for metrics in batch_metrics ] batch_nelbo_per_t = [ metrics['nelbo_per_t_batch'].reshape(-1) for metrics in batch_metrics ] for t, nelbo_per_t in zip(batch_t, batch_nelbo_per_t): kl_tracker.update(t, nelbo_per_t) ################### Process batch metrics ############################ metrics = { key: np.mean([metrics[key] for metrics in batch_metrics]) for key in batch_metrics[0] if 'batch' not in key } # Metric logging. if is_first_host: log_standard_metrics(writer, step, train_metrics=metrics) batch_metrics = [] if config.eval_every_steps and ( (step % config.eval_every_steps == 0) or is_last_step): with report_progress.timed('eval'): ####################### Run evaluation ############################### metrics, rng_eval = eval_model( p_eval_step, rng_eval, state, it['eval'], (ds_metadata['eval']['num_batches'] * config.get('num_eval_passes', 1))) # Metric logging. if is_first_host: log_standard_metrics(writer, step, eval_metrics=metrics) # Track KL (unrelated to the eval, but nice to not do every step). kl_values = kl_tracker.get_kl_per_t() kl_history.append(np.array(kl_values)) kl_history = kl_history[-50:] if config.sample_every_steps and ( (step % config.sample_every_steps == 0) or is_last_step): with report_progress.timed('sample'): ######################### Run sampling ############################### chain = model.sample(jax.random.fold_in(rng_sample, step), state.ema_params, config.sample_batch_size, chain_out_size=config.get( 'chain_out_size', model.num_stages)) if is_first_host: chain = jax.device_get(chain) long_sample = np.reshape(chain[-1], (1, -1, 1)).astype(np.float32) long_sample = (2. * long_sample) / config.num_classes - 1. writer.write_audios(step, {'samples': long_sample}, sample_rate=config.sample_rate) ######################### Checkpointing ################################# if is_first_host and config.checkpoint_every_steps and ( (step % config.checkpoint_every_steps == 0) or is_last_step): logging.info('Saving checkpoint: step %d', step) with report_progress.timed('checkpoint'): checkpoint.save_checkpoint( work_dir, state=flax.jax_utils.unreplicate(state), step=step) logging.info('Finished saving checkpoint: step %d', step) return state
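# --- Hedged numeric check (illustrative only) of the audio-sample rescaling
# above: discrete class ids in [0, num_classes) map to roughly [-1, 1) before
# being logged with writer.write_audios.
import numpy as np

_num_classes = 256
_ids = np.array([0., 128., 255.])
print((2. * _ids) / _num_classes - 1.)  # -> [-1.0, 0.0, 0.9921875]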
def train_and_evaluate(config, workdir): """Runs a training and evaluation loop. Args: config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains checkpoint training will be resumed from the latest checkpoint. """ if config.dataset.batch_size % jax.device_count() != 0: raise ValueError( "Batch size must be divisible by the number of devices.") tf.io.gfile.makedirs(workdir) # Deterministic training. rng = jax.random.PRNGKey(config.seed) # Shift the numpy random seed by process_index() to shuffle data loaded # by different hosts np.random.seed(20201473 + jax.process_index()) #---------------------------------------------------------------------------- # Build input pipeline. rng, data_rng = jax.random.split(rng) data_rng = jax.random.fold_in(data_rng, jax.process_index()) config.dataset.data_dir = os.path.join(config.dataset.base_dir, config.dataset.scene) train_ds, eval_ds = datasets.create_dataset(config) example_batch = train_ds.peek() #---------------------------------------------------------------------------- # Learning rate schedule. num_train_steps = config.train.max_steps if num_train_steps == -1: num_train_steps = train_ds.size() steps_per_epoch = num_train_steps // config.train.num_epochs logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps, steps_per_epoch) learning_rate_fn = train_utils.create_learning_rate_fn(config) #---------------------------------------------------------------------------- # Initialize model. rng, model_rng = jax.random.split(rng) model, state = models.create_train_state( config, model_rng, learning_rate_fn=learning_rate_fn, example_batch=example_batch, ) #---------------------------------------------------------------------------- # Set up checkpointing of the model and the input pipeline. state = checkpoints.restore_checkpoint(workdir, state) initial_step = int(state.step) + 1 #---------------------------------------------------------------------------- # Distribute training. state = flax_utils.replicate(state) p_train_step = jax.pmap( functools.partial( train_step, model=model, learning_rate_fn=learning_rate_fn, weight_decay=config.train.weight_decay, config=config, ), axis_name="batch", ) # Get distributed rendering function render_pfn = render_utils.get_render_function( model=model, config=config, randomized=False, # No randomization for evaluation. ) #---------------------------------------------------------------------------- # Prepare Metric Writers writer = metric_writers.create_default_writer( workdir, just_logging=jax.process_index() > 0) if initial_step == 1: writer.write_hparams(dict(config)) logging.info("Starting training loop at step %d.", initial_step) hooks = [] report_progress = periodic_actions.ReportProgress( num_train_steps=num_train_steps, writer=writer) if jax.process_index() == 0: hooks += [ report_progress, ] train_metrics = None # Prefetch_buffer_size = 6 x batch_size ptrain_ds = flax.jax_utils.prefetch_to_device(train_ds, 6) n_local_devices = jax.local_device_count() rng = rng + jax.process_index() # Make random seed separate across hosts. keys = jax.random.split(rng, n_local_devices) # For pmapping RNG keys. with metric_writers.ensure_flushes(writer): for step in range(initial_step, num_train_steps + 1): # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU # devices. 
is_last_step = step == num_train_steps with jax.profiler.StepTraceAnnotation("train", step_num=step): batch = next(ptrain_ds) state, metrics_update, keys = p_train_step(rng=keys, state=state, batch=batch) metric_update = flax_utils.unreplicate(metrics_update) train_metrics = (metric_update if train_metrics is None else train_metrics.merge(metric_update)) # Quick indication that training is happening. logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step) for h in hooks: h(step) if step % config.train.log_loss_every_steps == 0 or is_last_step: writer.write_scalars(step, train_metrics.compute()) train_metrics = None if step % config.train.render_every_steps == 0 or is_last_step: test_batch = next(eval_ds) test_pixels = model_utils.uint2float( test_batch.target_view.rgb) # extract for evaluation with report_progress.timed("eval"): pred_color, pred_disp, pred_acc = eval_step( state, keys[0], test_batch, render_pfn, config) #------------------------------------------------------------------ # Log metrics and images for host 0 #------------------------------------------------------------------ if jax.process_index() == 0: psnr = model_utils.compute_psnr( ((pred_color - test_pixels)**2).mean()) ssim = skmetrics.structural_similarity( pred_color.astype(np.float32), test_pixels.astype(np.float32), win_size=11, multichannel=True, gaussian_weight=True) writer.write_scalars( step, { "train_eval/test_psnr": psnr, "train_eval/test_ssim": ssim, }) writer.write_images( step, { "test_pred_color": pred_color[None, :], "test_target": test_pixels[None, :] }) if pred_disp is not None: writer.write_images( step, {"test_pred_disp": pred_disp[None, :]}) if pred_acc is not None: writer.write_images( step, {"test_pred_acc": pred_acc[None, :]}) #------------------------------------------------------------------ if (jax.process_index() == 0) and (step % config.train.checkpoint_every_steps == 0 or is_last_step): # Write final metrics to file with file_utils.open_file( os.path.join(workdir, "train_logs.json"), "w") as f: log_dict = metric_update.compute() for k, v in log_dict.items(): log_dict[k] = v.item() f.write(json.dumps(log_dict)) with report_progress.timed("checkpoint"): state_to_save = jax.device_get( jax.tree_map(lambda x: x[0], state)) checkpoints.save_checkpoint(workdir, state_to_save, step, keep=100) logging.info("Finishing training at step %d", num_train_steps)
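# --- A standard PSNR definition consistent with the compute_psnr(mse) call
# above, assuming pixel values in [0, 1]; the project's own helper may differ
# in details, so treat this as a sketch.
import numpy as np


def _psnr_from_mse_sketch(mse):
  return -10.0 * np.log10(mse)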
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str): """Runs a training and evaluation loop. Args: config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains checkpoint training will be resumed from the latest checkpoint. """ tf.io.gfile.makedirs(workdir) # Number of local devices for this host. n_devices = jax.local_device_count() if config.batch_size % n_devices: raise ValueError( "Batch size must be divisible by the number of devices") vocab_path = config.vocab_path if vocab_path is None: vocab_path = os.path.join(workdir, "sentencepiece_model") tf.io.gfile.makedirs(os.path.split(vocab_path)[0]) # Load Dataset # --------------------------------------------------------------------------- logging.info("Initializing dataset.") train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets( n_devices=n_devices, dataset_name=config.dataset_name, eval_dataset_name=config.eval_dataset_name, shard_idx=jax.host_id(), shard_count=jax.host_count(), vocab_path=vocab_path, target_vocab_size=config.vocab_size, batch_size=config.batch_size, max_corpus_chars=config.max_corpus_chars, max_length=config.max_target_length, max_eval_length=config.max_eval_target_length) train_iter = iter(train_ds) vocab_size = int(encoder.vocab_size()) eos_id = decode.EOS_ID # Default Sentencepiece EOS token. def decode_tokens(toks): valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32) return encoder.detokenize(valid_toks).numpy().decode("utf-8") if config.num_predict_steps > 0: predict_ds = predict_ds.take(config.num_predict_steps) logging.info("Initializing model, optimizer, and step functions.") # Build Model and Optimizer # --------------------------------------------------------------------------- train_config = models.TransformerConfig( vocab_size=vocab_size, output_vocab_size=vocab_size, share_embeddings=config.share_embeddings, logits_via_embedding=config.logits_via_embedding, dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32, emb_dim=config.emb_dim, num_heads=config.num_heads, num_layers=config.num_layers, qkv_dim=config.qkv_dim, mlp_dim=config.mlp_dim, max_len=max(config.max_target_length, config.max_eval_target_length), dropout_rate=config.dropout_rate, attention_dropout_rate=config.attention_dropout_rate, deterministic=False, decode=False, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6)) eval_config = train_config.replace(deterministic=True) predict_config = train_config.replace(deterministic=True, decode=True) start_step = 0 rng = random.PRNGKey(config.seed) rng, init_rng = random.split(rng) input_shape = (config.batch_size, config.max_target_length) target_shape = (config.batch_size, config.max_target_length) m = models.Transformer(eval_config) initial_variables = jax.jit(m.init)(init_rng, jnp.ones(input_shape, jnp.float32), jnp.ones(target_shape, jnp.float32)) # apply an optimizer to this tree optimizer_def = optim.Adam(config.learning_rate, beta1=0.9, beta2=0.98, eps=1e-9, weight_decay=config.weight_decay) optimizer = optimizer_def.create(initial_variables["params"]) # We access model params only from optimizer below via optimizer.target. del initial_variables if config.restore_checkpoints: # Restore unreplicated optimizer + model state from last checkpoint. optimizer = checkpoints.restore_checkpoint(workdir, optimizer) # Grab last step. 
start_step = int(optimizer.state.step) writer = metric_writers.create_default_writer( workdir, just_logging=jax.host_id() > 0) if start_step == 1: writer.write_hparams(dict(config)) # Replicate optimizer. optimizer = jax_utils.replicate(optimizer) learning_rate_fn = create_learning_rate_scheduler( base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps) # compile multidevice versions of train/eval/predict step and cache init fn. p_train_step = jax.pmap(functools.partial( train_step, config=train_config, learning_rate_fn=learning_rate_fn, label_smoothing=config.label_smoothing), axis_name="batch", donate_argnums=(0, )) # pytype: disable=wrong-arg-types p_eval_step = jax.pmap(functools.partial( eval_step, config=eval_config, label_smoothing=config.label_smoothing), axis_name="batch") p_init_cache = jax.pmap(functools.partial( initialize_cache, max_decode_len=config.max_predict_length, config=predict_config), axis_name="batch") p_pred_step = jax.pmap( functools.partial(predict_step, config=predict_config, beam_size=config.beam_size), axis_name="batch", static_broadcasted_argnums=(3, 4)) # eos token, max_length are constant # Main Train Loop # --------------------------------------------------------------------------- # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap"d training update for performance. dropout_rngs = random.split(rng, n_devices) logging.info("Starting training loop.") hooks = [] report_progress = periodic_actions.ReportProgress( num_train_steps=config.num_train_steps, writer=writer) if jax.host_id() == 0: hooks += [ report_progress, periodic_actions.Profile(num_profile_steps=5) ] metrics_all = [] with metric_writers.ensure_flushes(writer): for step, batch in zip(range(start_step, config.num_train_steps), train_iter): # Shard data to devices and do a training step. batch = common_utils.shard( jax.tree_map(lambda x: x._numpy(), batch)) # pylint: disable=protected-access optimizer, metrics, dropout_rngs = p_train_step( optimizer, batch, dropout_rng=dropout_rngs) metrics_all.append(metrics) # Quick indication that training is happening. logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step) for h in hooks: h(step) # Save a checkpoint on one host after every checkpoint_freq steps. if (config.save_checkpoints and step % config.checkpoint_freq == 0 and step > 0 and jax.host_id() == 0): checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(optimizer), step) # Periodic metric handling. 
if step % config.eval_frequency != 0 and step > 0: continue # Training Metrics logging.info("Gathering training metrics.") metrics_all = common_utils.get_metrics(metrics_all) lr = metrics_all.pop("learning_rate").mean() metrics_sums = jax.tree_map(jnp.sum, metrics_all) denominator = metrics_sums.pop("denominator") summary = jax.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop summary["learning_rate"] = lr summary = {"train_" + k: v for k, v in summary.items()} writer.write_scalars(step, summary) metrics_all = [] # Eval Metrics logging.info("Gathering evaluation metrics.") eval_metrics = [] eval_iter = iter(eval_ds) for _, eval_batch in zip(range(config.num_eval_steps), eval_iter): eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch) # pylint: disable=protected-access eval_batch = common_utils.shard(eval_batch) metrics = p_eval_step(optimizer.target, eval_batch) eval_metrics.append(metrics) eval_metrics = common_utils.get_metrics(eval_metrics) eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics) eval_denominator = eval_metrics_sums.pop("denominator") eval_summary = jax.tree_map( lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop eval_metrics_sums) eval_summary = {"eval_" + k: v for k, v in eval_summary.items()} writer.write_scalars(step, eval_summary) # Translation and BLEU Score. logging.info("Translating evaluation dataset.") t_inference_start = time.time() sources, references, predictions = [], [], [] for pred_batch in predict_ds: pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch) # pylint: disable=protected-access # Handle final odd-sized batch by padding instead of dropping it. cur_pred_batch_size = pred_batch["inputs"].shape[0] if cur_pred_batch_size % n_devices: padded_size = int( np.ceil(cur_pred_batch_size / n_devices) * n_devices) pred_batch = jax.tree_map( lambda x: pad_examples(x, padded_size), # pylint: disable=cell-var-from-loop pred_batch) pred_batch = common_utils.shard(pred_batch) cache = p_init_cache(pred_batch["inputs"]) predicted = p_pred_step(pred_batch["inputs"], optimizer.target, cache, eos_id, config.max_predict_length) predicted = tohost(predicted) inputs = tohost(pred_batch["inputs"]) targets = tohost(pred_batch["targets"]) # Iterate through non-padding examples of batch. for i, s in enumerate(predicted[:cur_pred_batch_size]): sources.append(decode_tokens(inputs[i])) references.append(decode_tokens(targets[i])) predictions.append(decode_tokens(s)) logging.info( "Translation: %d predictions %d references %d sources.", len(predictions), len(references), len(sources)) logging.info("Translation time: %.4f s step %d.", time.time() - t_inference_start, step) # Calculate BLEU score for translated eval corpus against reference. bleu_matches = bleu.bleu_partial(references, predictions) all_bleu_matches = per_host_sum_pmap(bleu_matches) bleu_score = bleu.complete_bleu(*all_bleu_matches) # Save translation samples for tensorboard. exemplars = "" for n in np.random.choice(np.arange(len(predictions)), 8): exemplars += f"{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n" writer.write_scalars(step, {"bleu": bleu_score}) writer.write_texts(step, {"samples": exemplars})
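# --- Hedged sketch of a `pad_examples` helper consistent with its use above:
# the last example is tiled so the prediction batch becomes a multiple of the
# device count, and the padded rows are discarded again after decoding. The
# real helper in the codebase may differ.
import numpy as np


def _pad_examples_sketch(x, desired_batch_size):
  batch_pad = desired_batch_size - x.shape[0]
  return np.concatenate([x, np.tile(x[-1], (batch_pad, 1))], axis=0)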
def train_and_evaluate(config, workdir): """Runs a training and evaluation loop. Args: config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains checkpoint training will be resumed from the latest checkpoint. """ tf.io.gfile.makedirs(workdir) rng = jax.random.PRNGKey(config.seed) # Build input pipeline. rng, data_rng = jax.random.split(rng) data_rng = jax.random.fold_in(data_rng, jax.host_id()) splits = input_pipeline.create_datasets(config, data_rng) num_classes = splits.info.features["label"].num_classes train_iter = iter(splits.train) # pytype: disable=wrong-arg-types # Learning rate schedule. num_train_steps = config.num_train_steps if num_train_steps == -1: num_train_steps = splits.train.cardinality().numpy() steps_per_epoch = num_train_steps // config.num_epochs logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps, steps_per_epoch) # We treat the learning rate in the config as the learning rate for batch size # 32 but scale it according to our batch size. global_batch_size = config.per_device_batch_size * jax.device_count() base_learning_rate = config.learning_rate * global_batch_size / 32.0 learning_rate_fn = functools.partial(get_learning_rate, base_learning_rate=base_learning_rate, steps_per_epoch=steps_per_epoch, num_epochs=config.num_epochs, warmup_epochs=config.warmup_epochs) # Initialize model. rng, model_rng = jax.random.split(rng) model, state = create_train_state( config, model_rng, input_shape=splits.train.element_spec["input"].shape[1:], num_classes=num_classes) # Set up checkpointing of the model and the input pipeline. checkpoint_dir = os.path.join(workdir, "checkpoints") ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, {"train_iter": train_iter}, max_to_keep=2) state = ckpt.restore_or_initialize(state) initial_step = int(state.step) + 1 # Count number of trainable parameters. This must be done before replicating # the state to avoid double-counting replicated parameters. param_count = sum(p.size for p in jax.tree_leaves(state.optimizer.target)) # Distribute training over local devices. state = flax_utils.replicate(state) p_train_step = jax.pmap(functools.partial( train_step, model=model, learning_rate_fn=learning_rate_fn, weight_decay=config.weight_decay), axis_name=_PMAP_AXIS_NAME) writer = metric_writers.create_default_writer( workdir, just_logging=jax.host_id() > 0) if initial_step == 1: writer.write_hparams(dict(config)) # Log the number of trainable params. writer.write_scalars(initial_step, {"param_count": param_count}) logging.info("Starting training loop at step %d.", initial_step) hooks = [] report_progress = periodic_actions.ReportProgress( num_train_steps=num_train_steps, writer=writer) if jax.host_id() == 0: hooks += [ report_progress, periodic_actions.Profile(num_profile_steps=5, logdir=workdir) ] train_metrics = None with metric_writers.ensure_flushes(writer): for step in range(initial_step, num_train_steps + 1): # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU # devices. is_last_step = step == num_train_steps with jax.profiler.StepTraceContext("train", step_num=step): batch = jax.tree_map(np.asarray, next(train_iter)) state, metrics_update = p_train_step(state=state, batch=batch) metric_update = flax_utils.unreplicate(metrics_update) train_metrics = (metric_update if train_metrics is None else train_metrics.merge(metric_update)) # Quick indication that training is happening. 
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                          step)
      for h in hooks:
        h(step)

      if step % config.log_loss_every_steps == 0 or is_last_step:
        writer.write_scalars(step, train_metrics.compute())
        train_metrics = None

      # When combining train and eval, we do not evaluate while training.
      if ((step % config.eval_every_steps == 0 or is_last_step) and
          not config.combine_train_val_and_eval_on_test):
        with report_progress.timed("eval"):
          eval_metrics = evaluate(model, state, splits.validation,
                                  config.num_eval_steps)
        writer.write_scalars(step, eval_metrics.compute())

      if step % config.checkpoint_every_steps == 0 or is_last_step:
        with report_progress.timed("checkpoint"):
          ckpt.save(flax_utils.unreplicate(state))

      if is_last_step and config.combine_train_val_and_eval_on_test:
        # Evaluate a single time on the test set when requested.
        with report_progress.timed("test"):
          test_metrics = evaluate(model, state, splits.test,
                                  config.num_eval_steps)
        writer.write_scalars(step, test_metrics.compute())

  logging.info("Finishing training at step %d", num_train_steps)
def train(base_dir, config):
  """Train function."""
  print(config)
  chkpt_manager = checkpoint.Checkpoint(str(base_dir / 'train'))

  writer = create_default_writer()

  # Initialize dataset
  key = jax.random.PRNGKey(config.seed)
  key, subkey = jax.random.split(key)
  ds = dataset.get_dataset(config, subkey, num_tasks=config.num_tasks)
  ds_iter = iter(ds)

  key, subkey = jax.random.split(key)
  encoder = MLPEncoder(**config.encoder)

  train_config = config.train.to_dict()
  train_method = train_config.pop('method')

  module_config = train_config.pop('module')
  module_class = module_config.pop('name')
  module = globals().get(module_class)(encoder, **module_config)

  train_step = globals().get(f'train_step_{train_method}')
  train_step = functools.partial(train_step, **train_config)

  params = module.init(subkey, next(ds_iter)[0])
  lr = optax.cosine_decay_schedule(config.learning_rate,
                                   config.num_train_steps)
  optim = optax.chain(optax.adam(lr),
                      # optax.adaptive_grad_clip(0.15)
                      )

  state = TrainState.create(apply_fn=module.apply, params=params, tx=optim)
  state = chkpt_manager.restore_or_initialize(state)

  # Hooks
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  hooks = [
      report_progress,
      periodic_actions.Profile(num_profile_steps=5, logdir=str(base_dir))
  ]

  def handle_preemption(signal_number, _):
    logging.info('Received signal %d, saving checkpoint.', signal_number)
    with report_progress.timed('checkpointing'):
      chkpt_manager.save(state)
    logging.info('Finished saving checkpoint.')

  signal.signal(signal.SIGTERM, handle_preemption)

  metrics = TrainMetrics.empty()
  with metric_writers.ensure_flushes(writer):
    for step in tqdm.tqdm(range(state.step, config.num_train_steps)):
      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        states, targets = next(ds_iter)
        state, metrics = train_step(state, metrics, states, targets)

      logging.log_first_n(logging.INFO, 'Finished training step %d', 5, step)

      if step % config.log_metrics_every == 0:
        writer.write_scalars(step, metrics.compute())
        metrics = TrainMetrics.empty()

      # if step % config.log_eval_metrics_every == 0 and isinstance(
      #     ds, dataset.MDPDataset):
      #   eval_metrics = evaluate_mdp(state, ds.aux_task_matrix, config)
      #   writer.write_scalars(step, eval_metrics.compute())

      for hook in hooks:
        hook(step)

  chkpt_manager.save(state)
  return state
      jax.random.normal(  # charlinel(why benefit of np?)
          weight_key, (d, num_tasks), dtype=jnp.float64))

  assert optimizer == 'sgd', 'Non-sgd not yet supported.'

  writer = metric_writers.create_default_writer(logdir=str(workdir))

  hooks = [
      periodic_actions.PeriodicCallback(
          every_steps=5_000,
          callback_fn=lambda step, t: chkpt_manager.save((step, Phi)))
  ]

  # Perform num_epochs gradient steps.
  with metric_writers.ensure_flushes(writer):
    for step in etqdm.tqdm(
        range(initial_step + 1, num_epochs + 1),
        initial=initial_step,
        total=num_epochs):
      # Draw one or many source states to update, and its task.
      source_states, key = utils.draw_states(num_states, main_batch_size, key)
      task, key = utils.draw_states(num_tasks, 1, key)  # bad Marc!

      # Use the source states to update our estimate of the feature norm.
      # Do this pre-LISSA, avoid a bad first gradient.
      if method == 'lissa' and estimate_feature_norm:
        max_norm = utils.compute_max_feature_norm(Phi[source_states, :])
        estimated_feature_norm += 0.01 * (max_norm - estimated_feature_norm)
def evaluate(base_dir, config, *, train_state):
  """Eval function."""
  chkpt_manager = checkpoint.Checkpoint(str(base_dir / 'eval'))

  writer = create_default_writer()

  key = jax.random.PRNGKey(config.eval.seed)
  model_init_key, ds_key = jax.random.split(key)

  linear_module = LinearModule(config.eval.num_tasks)
  params = linear_module.init(model_init_key,
                              jnp.zeros((config.encoder.embedding_dim,)))
  lr = optax.cosine_decay_schedule(config.eval.learning_rate,
                                   config.num_eval_steps)
  optim = optax.adam(lr)

  ds = dataset.get_dataset(config, ds_key, num_tasks=config.eval.num_tasks)
  ds_iter = iter(ds)

  state = TrainState.create(apply_fn=linear_module.apply,
                            params=params,
                            tx=optim)
  state = chkpt_manager.restore_or_initialize(state)

  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_eval_steps, writer=writer)
  hooks = [
      report_progress,
      periodic_actions.Profile(num_profile_steps=5, logdir=str(base_dir))
  ]

  def handle_preemption(signal_number, _):
    logging.info('Received signal %d, saving checkpoint.', signal_number)
    with report_progress.timed('checkpointing'):
      chkpt_manager.save(state)
    logging.info('Finished saving checkpoint.')

  signal.signal(signal.SIGTERM, handle_preemption)

  metrics = EvalMetrics.empty()
  with metric_writers.ensure_flushes(writer):
    for step in tqdm.tqdm(range(state.step, config.num_eval_steps)):
      with jax.profiler.StepTraceAnnotation('eval', step_num=step):
        states, targets = next(ds_iter)
        state, metrics = evaluate_step(train_state, state, metrics, states,
                                       targets)

      if step % config.log_metrics_every == 0:
        writer.write_scalars(step, metrics.compute())
        metrics = EvalMetrics.empty()

      for hook in hooks:
        hook(step)

  # Finally, evaluate on the true(ish) test aux task matrix.
  states, targets = dataset.EvalDataset(config, ds_key).get_batch()

  @jax.jit
  def loss_fn():
    outputs = train_state.apply_fn(train_state.params, states)
    phis = outputs.phi
    predictions = jax.vmap(state.apply_fn, in_axes=(None, 0))(state.params,
                                                              phis)
    return jnp.mean(optax.l2_loss(predictions, targets))

  test_loss = loss_fn()
  writer.write_scalars(config.num_eval_steps + 1, {'test_loss': test_loss})
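# --- Minimal sketch (with made-up shapes) of the vmap pattern in loss_fn
# above: the linear head's parameters are broadcast (in_axes=None) while the
# per-state features phi are mapped over their batch axis (in_axes=0).
import jax
import jax.numpy as jnp


def _linear_apply_sketch(params, phi):      # phi: [embedding_dim]
  return phi @ params['w'] + params['b']    # -> [num_tasks]


_params = {'w': jnp.zeros((8, 3)), 'b': jnp.zeros((3,))}
_phis = jnp.ones((32, 8))                   # [batch, embedding_dim]
_preds = jax.vmap(_linear_apply_sketch, in_axes=(None, 0))(_params, _phis)
# _preds.shape == (32, 3)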
def train(config: ml_collections.ConfigDict): """Run training.""" # Establish host information local_device_count = jax.local_device_count() host_count = jax.process_count() host_id = jax.process_index() task = task_registry.get_registered_task(config.task_name) start_step = 0 rng = jax.random.PRNGKey(config.seed) model_config = ml_collections.FrozenConfigDict(config.model_config) model = task.build_model(model_config) # Initialization needs to be pmapped because models use collective ops. # Create dummy input dummy_input = { key: jnp.tile(value, (local_device_count,) + (1,) * value.ndim) for key, value in task.dummy_input(config).items() } rng, init_rng = jax.random.split(rng) init_rng = jax.random.split(init_rng, local_device_count) logging.info('Initializing model.') initial_variables = jax.pmap( model.init, 'batch', static_broadcasted_argnums=2)(init_rng, dummy_input, True) logging.info('Finished initializing model.') initial_variables = initial_variables.unfreeze() if config.load_weights is not None: logging.info('Loading model weights from file') loaded_variables = task.load_weights(config) unexpected, missing = checkpoint_utils.merge_nested_dicts( initial_variables, loaded_variables) logging.info('*** Unexpected features: ***') for feature_name in unexpected: logging.info('\t%s', feature_name) logging.info('*** Missing features: ***') for feature_name in missing: logging.info('\t%s', feature_name) model_vars = { key: value for key, value in initial_variables.items() if key != 'params' } learning_rate_fn = optim_utils.create_learning_rate_scheduler( learning_rate=config.learning_rate, warmup=config.warmup, warmup_steps=config.get('warmup_steps', None), linear_decay=config.linear_decay, max_steps=config.num_train_steps, decay_minimum_factor=config.get('decay_minimum_factor', None), ) if config.weight_decay_exclude is not None: decay_mask = optim_utils.create_dict_mask(initial_variables['params'], config.weight_decay_exclude) else: decay_mask = None tx = optax.adamw( learning_rate=learning_rate_fn, weight_decay=config.weight_decay, b1=0.9, b2=0.999, eps=1e-6, mask=decay_mask) if config.grad_clip is not None: tx = optax.chain(tx, optax.clip_by_global_norm(config.grad_clip)) ignore_k_nans = config.get('ignore_k_nans') if ignore_k_nans is not None: tx = optax.apply_if_finite(tx, max_consecutive_errors=ignore_k_nans) loss_fn = task.make_loss_fn(config) train_state = ts.TrainState.create( apply_fn=loss_fn, params=jax_utils.unreplicate(initial_variables['params']), tx=tx, ) # We access model params only from train state. del initial_variables # Restore unreplicated train state from last checkpoint train_state = checkpoints.restore_checkpoint(config.model_dir, train_state) # Grab last step. 
start_step = int(train_state.step) writer = metric_writers.create_default_writer( config.model_dir, just_logging=jax.process_index() > 0) if start_step == 0: writer.write_hparams(config.to_dict()) dropout_rngs = jax.random.split(rng, local_device_count) del rng # Load datasets logging.info('Loading dataset.') # Make sure we don't re-use same data if we load weights or checkpoint seed = config.seed + start_step if config.load_weights: seed = seed + hash(config.load_weights) name_to_features = task.get_name_to_features(config) preprocess_fn = task.make_preprocess_fn(config) collater_fn = task.make_collater_fn(config) train_data = data_utils.load_multi_dataset( datasets_config=config.train_data, name_to_features=name_to_features, preprocess_fn=preprocess_fn, collater_fn=collater_fn, is_training=True, per_device_batch_size=config.per_device_batch_size, local_device_count=local_device_count, host_count=host_count, host_id=host_id, seed=config.seed, ) train_iter = iter(train_data) pad_eval = config.get('pad_eval', False) if pad_eval: logging.info('Eval data is padded such that none of samples are dropped.') else: logging.warn('Eval data is NOT padded -- some samples might be dropped.') eval_data = data_utils.load_multi_dataset( datasets_config=config.eval_data, name_to_features=name_to_features, preprocess_fn=preprocess_fn, collater_fn=collater_fn, is_training=False, per_device_batch_size=config.per_device_batch_size, local_device_count=local_device_count, host_count=host_count, host_id=host_id, seed=config.seed, pad_eval=pad_eval, ) eval_data = list(eval_data) logging.info('Loaded %d samples for evaluation.', len(eval_data)) # Setup postprocessing_fn for saving samples occasionally. if config.get('save_samples_every_steps') is not None: if config.get('save_samples_every_steps') % config.eval_every_steps != 0: raise ValueError( '`eval_every_steps` must divide `save_samples_every_steps`.') postprocessing_fn = task.make_output_postprocess_fn(config) # Training loop logging.info('Starting training.') # Replicate train state. train_state = jax_utils.replicate(train_state) # compile multidevice versions of train/eval/predict step p_train_step = jax.pmap( functools.partial( train_step, model_config=model_config, ), axis_name='batch', donate_argnums=(0,), ) # pytype: disable=wrong-arg-types p_eval_step = jax.pmap( functools.partial( eval_step, model_config=model_config, ), axis_name='batch') hooks = [] report_progress = periodic_actions.ReportProgress( num_train_steps=config.num_train_steps, writer=writer) if jax.process_index() == 0: hooks += [ report_progress, periodic_actions.Profile(logdir=config.model_dir, num_profile_steps=5) ] train_metrics = [] with metric_writers.ensure_flushes(writer): for step in range(start_step, config.num_train_steps): is_last_step = step == config.num_train_steps - 1 # Shard data to devices and perform a training step. with jax.profiler.StepTraceAnnotation('train', step_num=step): batch = jax.tree_map(jnp.asarray, train_iter.get_next()) train_state, metrics = p_train_step( train_state, model_vars, batch, dropout_rngs, ) train_metrics.append(metrics) # Quick indication that training is happening. logging.log_first_n(logging.INFO, 'Finished training step %d', 5, step) for h in hooks: h(step) # Periodic metric handling. 
      if step % config.eval_every_steps == 0 or is_last_step:
        with report_progress.timed('training_metrics'):
          logging.info('Gathering training metrics.')
          train_metrics = common_utils.get_metrics(train_metrics)
          metrics_sums = jax.tree_map(jnp.sum, train_metrics)
          summary = metric_utils.process_metrics(metrics_sums, prefix='train')
          summary['learning_rate'] = learning_rate_fn(step)
          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed('eval'):
          eval_results, eval_auxiliary = evaluate(
              eval_step_fn=p_eval_step,
              train_state=train_state,
              model_vars=model_vars,
              eval_data=eval_data,
          )
          writer.write_scalars(step, eval_results)

        if config.get('save_samples_every_steps') is not None:
          with report_progress.timed('save_samples'):
            if config.get('save_first_batch_only', True):
              # Only postprocess the first eval batch.
              eval_auxiliary = eval_auxiliary[:1]
            eval_processed = [
                postprocessing_fn(batch, auxiliary_output)
                for batch, auxiliary_output in eval_auxiliary
            ]
            data_utils.save_samples_to_json(eval_processed, config, step)

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = (
          step % config.checkpoint_every_steps == 0 or is_last_step)
      if (config.save_checkpoints and save_checkpoint and
          jax.process_index() == 0):
        with report_progress.timed('checkpoint'):
          logging.info('Saving checkpoint at step %s', step)
          checkpoints.save_checkpoint(
              config.model_dir,
              jax_utils.unreplicate(train_state),
              step,
              keep=config.get('keep_checkpoints', 1),
              keep_every_n_steps=config.get('keep_checkpoint_every_steps'),
          )

      save_model = (
          config.save_every_steps and
          (step % config.save_every_steps == 0 or is_last_step) and
          step != 0)
      if save_model and jax.process_index() == 0:
        with report_progress.timed('checkpoint'):
          logging.info('Saving weights at step %s', step)
          save_path = os.path.join(config.model_dir, 'weights',
                                   'step' + str(step))
          # By default, save only encoder weights.
          weights = jax_utils.unreplicate(train_state).params['encoder']
          checkpoint_utils.save_weights(save_path, weights)
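# Illustrative sketch (not part of the training code above) of the optimizer
# stack assembled in the `train` function: AdamW with a weight-decay mask,
# global-norm gradient clipping, and `optax.apply_if_finite` to skip a bounded
# number of non-finite updates. One difference worth noting: clipping is
# chained *before* AdamW here, so it acts on the raw gradients, whereas the
# loop above chains it after the optimizer and therefore clips the updates.
# `make_optimizer` and the toy mask are assumptions for illustration only.
import jax
import jax.numpy as jnp
import optax


def make_optimizer(learning_rate, weight_decay, grad_clip=None,
                   max_nan_steps=None):
  def decay_mask(params):
    # Toy stand-in for `create_dict_mask`: decay only matrices, not biases.
    return jax.tree_util.tree_map(lambda p: p.ndim > 1, params)

  tx = optax.adamw(
      learning_rate=learning_rate,
      b1=0.9, b2=0.999, eps=1e-6,
      weight_decay=weight_decay,
      mask=decay_mask)
  if grad_clip is not None:
    tx = optax.chain(optax.clip_by_global_norm(grad_clip), tx)
  if max_nan_steps is not None:
    tx = optax.apply_if_finite(tx, max_consecutive_errors=max_nan_steps)
  return tx


# Usage on a toy quadratic loss.
toy_params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros((2,))}
toy_tx = make_optimizer(1e-3, weight_decay=0.01, grad_clip=1.0, max_nan_steps=3)
toy_opt_state = toy_tx.init(toy_params)
toy_grads = jax.grad(
    lambda p: jnp.sum(p['w'] ** 2) + jnp.sum(p['b'] ** 2))(toy_params)
toy_updates, toy_opt_state = toy_tx.update(toy_grads, toy_opt_state, toy_params)
toy_params = optax.apply_updates(toy_params, toy_updates)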
def run_train(self, experiment_dir, work_unit_dir, rng): """Training loop with fixed number of steps and checkpoint every steps.""" del experiment_dir # unused tf.io.gfile.makedirs(work_unit_dir) config = self.config total_bs = config.train.batch_size assert total_bs % jax.device_count() == 0, ( f'num total devices {jax.device_count()} must divide the batch size ' f'{total_bs}') device_bs = total_bs // jax.device_count() logging.info('total_bs=%d device_bs=%d', total_bs, device_bs) # Logging setup writer = metric_writers.create_default_writer( work_unit_dir, just_logging=jax.host_id() > 0) if jax.host_id() == 0: utils.write_config_json(config, os.path.join(work_unit_dir, 'config.json')) # Build input pipeline logging.info('Substeps per training step: %d', config.train.substeps) train_ds = self.dataset.get_tf_dataset( split='train', batch_shape=( jax.local_device_count(), # for pmap config.train.substeps, # for lax.scan over multiple substeps device_bs, # batch size per device ), global_rng=jax.random.PRNGKey(config.seed), repeat=True, shuffle=True, augment=True, shard_id=jax.host_id(), num_shards=jax.host_count()) train_iter = utils.numpy_iter(train_ds) eval_ds = self.dataset.get_tf_dataset( split='eval', batch_shape=(jax.local_device_count(), device_bs), global_rng=jax.random.PRNGKey(config.seed), repeat=True, shuffle=True, augment=False, shard_id=jax.host_id(), num_shards=jax.host_count()) eval_iter = utils.numpy_iter(eval_ds) samples_shape = (device_bs, *self.dataset.data_shape) self.p_gen_samples = utils.dist( functools.partial(self._gen_samples, samples_shape=samples_shape), accumulate='concat', axis_name='batch') # Set up model and training state state = jax.device_get(self.make_init_state()) checkpoint_dir = os.path.join(work_unit_dir, 'checkpoints') state = checkpoints.restore_checkpoint(checkpoint_dir, state) initial_step = int(state.step) state = flax.jax_utils.replicate(state) # Training step train_step = functools.partial(self.step_fn, next(rng), True) train_step = functools.partial(jax.lax.scan, train_step) # for substeps train_step = jax.pmap(train_step, axis_name='batch', donate_argnums=(0,)) # Eval step (does not modify parameters; no substeps) eval_base_rng = next(rng) # Training loop logging.info('Entering training loop at step %i', initial_step) utils.assert_synced(state) last_log_time = last_ckpt_time = time.time() prev_step = initial_step with metric_writers.ensure_flushes(writer): for batch in train_iter: state, metrics = train_step(state, batch) new_step = int(state.step[0]) assert new_step == prev_step + config.train.substeps # Quick indication that training is happening. 
logging.log_first_n(logging.INFO, 'Finished training step %d', 5, new_step) # Log metrics if new_step % config.train.log_loss_every_steps == 0: # Unreplicate metrics, average over substeps, and cast to python float metrics = jax.device_get(flax.jax_utils.unreplicate(metrics)) def avg_over_substeps(x): assert x.shape[0] == config.train.substeps return float(x.mean(axis=0)) metrics = jax.tree_map(avg_over_substeps, metrics) metrics['train/steps_per_sec'] = float( config.train.log_loss_every_steps / (time.time() - last_log_time)) writer.write_scalars(new_step, metrics) last_log_time = time.time() # Eval should_eval = new_step % config.train.eval_every_steps == 0 if prev_step == 0 or should_eval: # Samples samples_to_log = { 'eval/samples': self.get_model_samples( params=state.ema_params, rng=next(rng)) } if samples_to_log: assert all(v.shape == (total_bs, *self.dataset.data_shape) for v in samples_to_log.values()) # tf.summary.image asks for a batch, so insert a new axis writer.write_images( new_step, { k: utils.np_tile_imgs(v.astype('uint8'))[None, :, :, :] for k, v in samples_to_log.items() }) # Eval metrics if config.train.get('calc_eval_metrics', True): eval_metrics = self._calc_eval_metrics( state=state, eval_iter=eval_iter, eval_steps=config.train.get('eval_number_steps', self.dataset.num_eval // total_bs), eval_base_rng=eval_base_rng, total_bs=total_bs) if eval_metrics is not None: writer.write_scalars(new_step, eval_metrics) # Checkpointing: only if checkpoint_every_secs is not None. if config.train.checkpoint_every_secs is not None: should_ckpt = ( time.time() - last_ckpt_time >= config.train.checkpoint_every_secs) should_ckpt = ( prev_step == 0 or new_step == config.train.num_train_steps or should_ckpt) else: should_ckpt = False if should_ckpt and jax.host_id() == 0: checkpoints.save_checkpoint( checkpoint_dir, flax.jax_utils.unreplicate(state), step=new_step, keep=3) last_ckpt_time = time.time() # Keep extra checkpoints without removal. Training does not resume # from these checkpoints. if (('retain_checkpoint_every_steps' in config.train) and ((new_step % config.train.retain_checkpoint_every_steps == 0) or (new_step == config.train.num_train_steps)) and (jax.host_id() == 0)): # Below, overwrite=True because training might resume from a # checkpoint from an earlier step than the latest retained checkpoint, # causing the latest retained checkpoint to be overwritten. checkpoints.save_checkpoint( os.path.join(work_unit_dir, 'retained_checkpoints'), flax.jax_utils.unreplicate(state), step=new_step, keep=int(1e10), overwrite=True) prev_step = new_step if new_step == config.train.num_train_steps: logging.info('Finished training for %d iterations.', new_step) break
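# Illustrative sketch (assumed helper names, not the class above) of the
# substep pattern used in `run_train`: a per-substep update wrapped with
# `jax.lax.scan` and then `jax.pmap`, so batches arrive with shape
# [local_devices, substeps, device_batch, ...].
import functools
import jax
import jax.numpy as jnp


def substep_fn(params, batch):
  # One SGD substep on a toy linear regression; the carry is the params.
  def loss_fn(p):
    pred = batch['x'] @ p['w']
    return jnp.mean((pred - batch['y']) ** 2)
  loss, grads = jax.value_and_grad(loss_fn)(params)
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
  return params, {'loss': loss}


# lax.scan folds the substep axis; pmap folds the device axis.
scanned_step = functools.partial(jax.lax.scan, substep_fn)
p_scanned_step = jax.pmap(scanned_step, axis_name='batch', donate_argnums=(0,))

n_dev, substeps, device_bs, dim = jax.local_device_count(), 4, 8, 3
toy_state = jax.device_put_replicated({'w': jnp.zeros((dim, 1))},
                                      jax.local_devices())
toy_batch = {
    'x': jnp.ones((n_dev, substeps, device_bs, dim)),
    'y': jnp.ones((n_dev, substeps, device_bs, 1)),
}
toy_state, toy_metrics = p_scanned_step(toy_state, toy_batch)
# toy_metrics['loss'] has shape [n_dev, substeps]; the loop above averages
# over the substep axis before logging.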
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str): """Runs a training and evaluation loop. Args: config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains checkpoint training will be resumed from the latest checkpoint. """ tf.io.gfile.makedirs(workdir) vocab_path = config.vocab_path if vocab_path is None: vocab_path = os.path.join(workdir, "sentencepiece_model") config.vocab_path = vocab_path tf.io.gfile.makedirs(os.path.split(vocab_path)[0]) # Load Dataset # --------------------------------------------------------------------------- logging.info("Initializing dataset.") train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets( n_devices=jax.local_device_count(), config=config, reverse_translation=config.reverse_translation, vocab_path=vocab_path) train_iter = iter(train_ds) vocab_size = int(encoder.vocab_size()) eos_id = decode.EOS_ID # Default Sentencepiece EOS token. def decode_tokens(toks): valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32) return encoder.detokenize(valid_toks).numpy().decode("utf-8") if config.num_predict_steps > 0: predict_ds = predict_ds.take(config.num_predict_steps) logging.info("Initializing model, optimizer, and step functions.") # Build Model and Optimizer # --------------------------------------------------------------------------- train_config = models.TransformerConfig( vocab_size=vocab_size, output_vocab_size=vocab_size, share_embeddings=config.share_embeddings, logits_via_embedding=config.logits_via_embedding, dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32, emb_dim=config.emb_dim, num_heads=config.num_heads, num_layers=config.num_layers, qkv_dim=config.qkv_dim, mlp_dim=config.mlp_dim, max_len=max(config.max_target_length, config.max_eval_target_length), dropout_rate=config.dropout_rate, attention_dropout_rate=config.attention_dropout_rate, deterministic=False, decode=False, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6)) eval_config = train_config.replace(deterministic=True) predict_config = train_config.replace(deterministic=True, decode=True) start_step = 0 rng = jax.random.PRNGKey(config.seed) rng, init_rng = jax.random.split(rng) input_shape = (config.per_device_batch_size, config.max_target_length) target_shape = (config.per_device_batch_size, config.max_target_length) m = models.Transformer(eval_config) initial_variables = jax.jit(m.init)(init_rng, jnp.ones(input_shape, jnp.float32), jnp.ones(target_shape, jnp.float32)) # apply an optimizer to this tree optimizer_def = optim.Adam( config.learning_rate, beta1=0.9, beta2=0.98, eps=1e-9, weight_decay=config.weight_decay) optimizer = optimizer_def.create(initial_variables["params"]) # We access model params only from optimizer below via optimizer.target. del initial_variables if config.restore_checkpoints: # Restore unreplicated optimizer + model state from last checkpoint. optimizer = checkpoints.restore_checkpoint(workdir, optimizer) # Grab last step. start_step = int(optimizer.state.step) writer = metric_writers.create_default_writer( workdir, just_logging=jax.host_id() > 0) if start_step == 0: writer.write_hparams(dict(config)) # Replicate optimizer. optimizer = jax_utils.replicate(optimizer) learning_rate_fn = create_learning_rate_scheduler( base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps) # compile multidevice versions of train/eval/predict step and cache init fn. 
p_train_step = jax.pmap( functools.partial( train_step, config=train_config, learning_rate_fn=learning_rate_fn, label_smoothing=config.label_smoothing), axis_name="batch", donate_argnums=(0,)) # pytype: disable=wrong-arg-types p_eval_step = jax.pmap( functools.partial( eval_step, config=eval_config), axis_name="batch") p_init_cache = jax.pmap( functools.partial( initialize_cache, max_decode_len=config.max_predict_length, config=predict_config), axis_name="batch") p_pred_step = jax.pmap( functools.partial( predict_step, config=predict_config, beam_size=config.beam_size), axis_name="batch", static_broadcasted_argnums=(3, 4)) # eos token, max_length are constant # Main Train Loop # --------------------------------------------------------------------------- # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap"d training update for performance. dropout_rngs = jax.random.split(rng, jax.local_device_count()) del rng logging.info("Starting training loop.") hooks = [] report_progress = periodic_actions.ReportProgress( num_train_steps=config.num_train_steps, writer=writer) if jax.host_id() == 0: hooks += [ report_progress, periodic_actions.Profile(logdir=workdir, num_profile_steps=5) ] train_metrics = [] with metric_writers.ensure_flushes(writer): for step in range(start_step, config.num_train_steps): is_last_step = step == config.num_train_steps - 1 # Shard data to devices and do a training step. with jax.profiler.StepTraceAnnotation("train", step_num=step): batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter))) optimizer, metrics = p_train_step( optimizer, batch, dropout_rng=dropout_rngs) train_metrics.append(metrics) # Quick indication that training is happening. logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step) for h in hooks: h(step) # Periodic metric handling. if step % config.eval_every_steps == 0 or is_last_step: with report_progress.timed("training_metrics"): logging.info("Gathering training metrics.") train_metrics = common_utils.get_metrics(train_metrics) lr = train_metrics.pop("learning_rate").mean() metrics_sums = jax.tree_map(jnp.sum, train_metrics) denominator = metrics_sums.pop("denominator") summary = jax.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop summary["learning_rate"] = lr summary = {"train_" + k: v for k, v in summary.items()} writer.write_scalars(step, summary) train_metrics = [] with report_progress.timed("eval"): eval_results = evaluate( p_eval_step=p_eval_step, target=optimizer.target, eval_ds=eval_ds, num_eval_steps=config.num_eval_steps) writer.write_scalars( step, {"eval_" + k: v for k, v in eval_results.items()}) with report_progress.timed("translate_and_bleu"): exemplars, bleu_score = translate_and_calculate_bleu( p_pred_step=p_pred_step, p_init_cache=p_init_cache, target=optimizer.target, predict_ds=predict_ds, decode_tokens=decode_tokens, max_predict_length=config.max_predict_length) writer.write_scalars(step, {"bleu": bleu_score}) writer.write_texts(step, {"samples": exemplars}) # Save a checkpoint on one host after every checkpoint_freq steps. save_checkpoint = (step % config.checkpoint_every_steps == 0 or is_last_step) if config.save_checkpoints and save_checkpoint and jax.host_id() == 0: with report_progress.timed("checkpoint"): checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(optimizer), step)
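# Illustrative sketch of the periodic metric handling above: each train step
# emits summed metrics plus a `denominator` (e.g. the number of target
# tokens); the host stacks the per-step dicts, sums them, and divides.
# `summarize_train_metrics` is an assumed name for illustration.
import jax
import jax.numpy as jnp


def summarize_train_metrics(metric_list):
  stacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *metric_list)
  sums = jax.tree_util.tree_map(jnp.sum, stacked)
  denominator = sums.pop('denominator')
  return {'train_' + k: v / denominator for k, v in sums.items()}


example_metrics = [
    {'loss': jnp.array(12.0), 'accuracy': jnp.array(30.0),
     'denominator': jnp.array(64.0)},
    {'loss': jnp.array(10.0), 'accuracy': jnp.array(40.0),
     'denominator': jnp.array(64.0)},
]
# -> train_loss = 22 / 128 = 0.171875, train_accuracy = 70 / 128 = 0.546875
print(summarize_train_metrics(example_metrics))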
def train(*, workdir, compute_phi, compute_psi, params, optimal_subspace,
          num_epochs, learning_rate, key, method, lissa_kappa, optimizer,
          covariance_batch_size, main_batch_size, weight_batch_size, d,
          num_tasks, compute_feature_norm_on_oracle_states, sample_states,
          eval_states, use_tabular_gradient=True):
  """Training function.

  For lissa, the total number of samples is 2 x covariance_batch_size +
  main_batch_size + 2 x weight_batch_size.

  Args:
    workdir: Work directory, where we'll save logs.
    compute_phi: A function that takes params and states and returns a matrix
      of phis.
    compute_psi: A function that takes an array of states and an array of tasks
      and returns Psi[states, tasks].
    params: Parameters used as the first argument for compute_phi.
    optimal_subspace: Top-d left singular vectors of Psi.
    num_epochs: How many gradient steps to perform (despite the name, these are
      individual steps, not passes over the data).
    learning_rate: The step size parameter for the optimizer.
    key: The jax prng key.
    method: 'naive', 'lissa', 'oracle', or 'explicit'.
    lissa_kappa: The parameter of the lissa method, if used.
    optimizer: Which optimizer to use; 'sgd' and 'adam' are supported.
    covariance_batch_size: The 'J' parameter. For the naive method, this is how
      many states we sample to construct the inverse. For the lissa method,
      ditto -- these are also "iterations".
    main_batch_size: How many states to update at once.
    weight_batch_size: How many states to construct the weight vector.
    d: The dimension of the representation.
    num_tasks: The total number of tasks.
    compute_feature_norm_on_oracle_states: If True, computes the feature norm
      using the oracle states (all the states in synthetic experiments).
      Otherwise, computes the norm using the sampled batch. Only applies to
      LISSA.
    sample_states: A function that takes an rng key and a number of states to
      sample, and returns a tuple containing (a vector of sampled states, an
      updated rng key).
    eval_states: An array of states to use to compute metrics on. This will be
      used to compute Phi = compute_phi(params, eval_states).
    use_tabular_gradient: If True, the train step will calculate the gradient
      using the tabular calculation. Otherwise, it will use a jax.vjp to
      backpropagate the gradient.
  """
  # Create an explicit weight matrix (needed for the explicit method only).
  if method == 'explicit':
    key, weight_key = jax.random.split(key)
    explicit_weight_matrix = jax.random.normal(
        weight_key, (d, num_tasks), dtype=jnp.float32)
    params['explicit_weight_matrix'] = explicit_weight_matrix

  if optimizer == 'sgd':
    optimizer = optax.sgd(learning_rate)
  elif optimizer == 'adam':
    optimizer = optax.adam(learning_rate)
  else:
    raise ValueError(f'Unknown optimizer {optimizer}.')
  optimizer_state = optimizer.init(params)

  chkpt_manager = checkpoint.Checkpoint(base_directory=_WORKDIR.value)
  initial_step, params, optimizer_state = chkpt_manager.restore_or_initialize(
      (0, params, optimizer_state))

  writer = metric_writers.create_default_writer(logdir=str(workdir))

  # Checkpointing and logging too much can use a lot of disk space.
  # Therefore, we don't want to checkpoint more than 10 times per experiment,
  # or keep more than 1k Phis per experiment.
  checkpoint_period = max(num_epochs // 10, 100_000)
  log_period = max(1_000, num_epochs // 1_000)

  def _checkpoint_callback(step, t, params, optimizer_state):
    del t  # Unused.
chkpt_manager.save((step, params, optimizer_state)) hooks = [ periodic_actions.PeriodicCallback(every_steps=checkpoint_period, callback_fn=_checkpoint_callback) ] fixed_train_kwargs = { 'compute_phi': compute_phi, 'compute_psi': compute_psi, 'optimizer': optimizer, 'method': method, # In the tabular case, the eval_states are all the states. 'oracle_states': eval_states, 'lissa_kappa': lissa_kappa, 'main_batch_size': main_batch_size, 'covariance_batch_size': covariance_batch_size, 'weight_batch_size': weight_batch_size, 'd': d, 'num_tasks': num_tasks, 'compute_feature_norm_on_oracle_states': (compute_feature_norm_on_oracle_states), 'sample_states': sample_states, 'use_tabular_gradient': use_tabular_gradient, } variable_kwargs = { 'params': params, 'optimizer_state': optimizer_state, 'key': key, } @jax.jit def _eval_step(phi_params): eval_phi = compute_phi(phi_params, eval_states) eval_psi = compute_psi(eval_states) # pytype: disable=wrong-arg-count metrics = compute_metrics(eval_phi, optimal_subspace) metrics |= {'frob_norm': utils.outer_objective_mc(eval_phi, eval_psi)} return metrics # Perform num_epochs gradient steps. with metric_writers.ensure_flushes(writer): for step in etqdm.tqdm(range(initial_step + 1, num_epochs + 1), initial=initial_step, total=num_epochs): variable_kwargs = _train_step(**fixed_train_kwargs, **variable_kwargs) if step % log_period == 0: metrics = _eval_step(variable_kwargs['params']['phi_params']) writer.write_scalars(step, metrics) for hook in hooks: hook(step, params=variable_kwargs['params'], optimizer_state=variable_kwargs['optimizer_state']) writer.flush()
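# Illustrative sketch: a simplified stand-in (not clu's implementation) for
# the periodic-callback pattern above, where checkpointing fires every
# `checkpoint_period` steps while metrics are only logged every `log_period`
# steps. All names below are assumptions for illustration.
import time


class EverySteps:
  """Calls `callback_fn` every `every_steps` steps, passing step and time."""

  def __init__(self, every_steps, callback_fn):
    self.every_steps = every_steps
    self.callback_fn = callback_fn

  def __call__(self, step, **kwargs):
    if step % self.every_steps == 0:
      self.callback_fn(step=step, t=time.time(), **kwargs)


def _toy_checkpoint_callback(step, t, params, optimizer_state):
  del t, params, optimizer_state  # A real callback would save these.
  print(f'checkpointing at step {step}')


toy_num_epochs = 10_000
checkpoint_hook = EverySteps(max(toy_num_epochs // 10, 1_000),
                             _toy_checkpoint_callback)
for toy_step in range(1, toy_num_epochs + 1):
  # ... one gradient step would go here ...
  checkpoint_hook(toy_step, params={}, optimizer_state={})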
def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') # Make sure tf does not allocate gpu memory. tf.config.experimental.set_visible_devices([], 'GPU') if FLAGS.jax_backend_target: jax.config.FLAGS.jax_xla_backend = 'tpu_driver' jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target # Number of local devices for this host. n_devices = jax.local_device_count() if jax.process_index() == 0: tf.io.gfile.makedirs(FLAGS.model_dir) if FLAGS.batch_size % n_devices: raise ValueError( 'Batch size must be divisible by the number of devices') vocab_path = FLAGS.vocab_path if vocab_path is None: vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model') tf.io.gfile.makedirs(os.path.split(vocab_path)[0]) # Load Dataset # --------------------------------------------------------------------------- logging.info('Initializing dataset.') train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets( dataset_name=FLAGS.dataset_name, eval_dataset_name=FLAGS.eval_dataset_name, shard_idx=jax.process_index(), shard_count=jax.process_count(), data_dir=FLAGS.data_dir, vocab_path=vocab_path, target_vocab_size=FLAGS.vocab_size, batch_size=FLAGS.batch_size, max_length=FLAGS.max_target_length, max_eval_length=FLAGS.max_eval_target_length, paracrawl_size=FLAGS.paracrawl_size, is_scores_path=FLAGS.is_scores_path, num_to_keep=FLAGS.data_selection_size, pseudo_path=FLAGS.pseudo_path, repeat_count=FLAGS.repeat_count, newscommentary_size=FLAGS.newscommentary_size, split_tokenizer=FLAGS.split_tokenizer) if FLAGS.aux_eval_dataset: aux_datasets = [] aux_names = FLAGS.aux_eval_dataset.split(',') for name in aux_names: _, aux_eval_ds, _, _ = input_pipeline.get_wmt_datasets( dataset_name=name, eval_dataset_name=None, shard_idx=jax.process_index(), shard_count=jax.process_count(), data_dir=FLAGS.data_dir, vocab_path=vocab_path, target_vocab_size=FLAGS.vocab_size, batch_size=FLAGS.batch_size, max_length=FLAGS.max_target_length, max_eval_length=FLAGS.max_eval_target_length, paracrawl_size=FLAGS.paracrawl_size, is_scores_path=FLAGS.is_scores_path, num_to_keep=FLAGS.data_selection_size, pseudo_path=FLAGS.pseudo_path, repeat_count=FLAGS.repeat_count, newscommentary_size=FLAGS.newscommentary_size) aux_datasets.append(aux_eval_ds) train_iter = iter(train_ds) vocab_size = int(encoder.vocab_size()) eos_id = decode.EOS_ID # Default Sentencepiece EOS token. 
def decode_tokens(toks): valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32) return encoder.detokenize(valid_toks).numpy().decode('utf-8') logging.info('Initializing model, optimizer, and step functions.') # Build Model and Optimizer # --------------------------------------------------------------------------- train_config = models.TransformerConfig( vocab_size=vocab_size, output_vocab_size=vocab_size, share_embeddings=FLAGS.share_embeddings, logits_via_embedding=FLAGS.logits_via_embedding, dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32, emb_dim=FLAGS.emb_dim, num_heads=FLAGS.num_heads, num_layers=FLAGS.num_layers, qkv_dim=FLAGS.qkv_dim, mlp_dim=FLAGS.mlp_dim, max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length), dropout_rate=FLAGS.dropout_rate, attention_dropout_rate=FLAGS.attention_dropout_rate, deterministic=False, decode=False, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6)) eval_config = train_config.replace(deterministic=True) predict_config = train_config.replace(deterministic=True, decode=True) start_step = 0 rng = jax.random.PRNGKey(FLAGS.random_seed) rng, init_rng = jax.random.split(rng) # It's possible that is supposed to be per device batch size input_shape = (FLAGS.batch_size, FLAGS.max_target_length) target_shape = (FLAGS.batch_size, FLAGS.max_target_length) m = models.Transformer(eval_config) initial_variables = jax.jit(m.init)(init_rng, jnp.ones(input_shape, jnp.float32), jnp.ones(target_shape, jnp.float32)) # apply an optimizer to this tree optimizer_def = optim.Adam(FLAGS.learning_rate, beta1=0.9, beta2=0.98, eps=1e-9, weight_decay=FLAGS.weight_decay) optimizer = optimizer_def.create(initial_variables['params']) # We access model params only from optimizer below via optimizer.target. del initial_variables if FLAGS.restore_checkpoints: logging.info('Restoring checkpoint.') # If we have a pretrained model, use that. Else, just continue where leftoff model_path = FLAGS.pretrained_model_dir if FLAGS.pretrained_model_dir else FLAGS.model_dir optimizer = checkpoints.restore_checkpoint(model_path, optimizer) # Grab last step. start_step = int(optimizer.state.step) writer = metric_writers.create_default_writer( FLAGS.model_dir, just_logging=jax.process_index() > 0) flag_key = [ k for k in FLAGS.flags_by_module_dict().keys() if 'wmt.par' in k ] if flag_key: flag_key = flag_key[0] local_flags = { f.name: f.value for f in FLAGS.flags_by_module_dict()[flag_key] } writer.write_hparams(local_flags) # Replicate optimizer. optimizer = jax_utils.replicate(optimizer) learning_rate_fn = common.create_learning_rate_scheduler( base_learning_rate=FLAGS.learning_rate, warmup_steps=FLAGS.warmup_steps, steps_per_cycle=FLAGS.steps_per_cycle, init_step=start_step, finetune_lr=FLAGS.finetune_lr) # compile multidevice versions of train/eval/predict step and cache init fn. 
p_train_step = jax.pmap(functools.partial( train_util.train_step, config=train_config, learning_rate_fn=learning_rate_fn, label_smoothing=FLAGS.label_smoothing), axis_name='batch', donate_argnums=(0, )) # pytype: disable=wrong-arg-types p_eval_step = jax.pmap(functools.partial(train_util.eval_step, config=eval_config), axis_name='batch') p_init_cache = jax.pmap(functools.partial( train_util.initialize_cache, max_decode_len=FLAGS.max_predict_length, config=predict_config), axis_name='batch') p_pred_step = jax.pmap( functools.partial(train_util.predict_step, config=predict_config, beam_size=FLAGS.beam_size), axis_name='batch', static_broadcasted_argnums=(3, 4)) # eos token, max_length are constant # Main Train Loop # --------------------------------------------------------------------------- # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap"d training update for performance. dropout_rngs = jax.random.split(rng, jax.local_device_count()) del rng logging.info('Starting training loop.') hooks = [] report_progress = periodic_actions.ReportProgress( num_train_steps=FLAGS.num_train_steps, writer=writer) if jax.process_index() == 0: hooks += [ report_progress, periodic_actions.Profile(logdir=FLAGS.model_dir, num_profile_steps=5) ] train_metrics = [] total_steps = start_step + FLAGS.num_train_steps if FLAGS.eval_only: total_steps = start_step + 1 best_eval_loss = 1000 curr_eval_loss = 1000 eval_loss_history = [] last_eval_step = 0 do_resample_data = False gradual_selection_size = FLAGS.data_selection_size dynamic_eval_freq = FLAGS.eval_frequency with metric_writers.ensure_flushes(writer): for step in range(start_step, total_steps): is_last_step = step == total_steps - 1 # Resample training data for gradual FT if do_resample_data: # resample data do_resample_data = False gradual_selection_size *= .7 dynamic_eval_freq = int(gradual_selection_size / 1000 / 4) train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets( dataset_name=FLAGS.dataset_name, eval_dataset_name=FLAGS.eval_dataset_name, shard_idx=jax.process_index(), shard_count=jax.process_count(), data_dir=FLAGS.data_dir, vocab_path=vocab_path, target_vocab_size=FLAGS.vocab_size, batch_size=FLAGS.batch_size, max_length=FLAGS.max_target_length, max_eval_length=FLAGS.max_eval_target_length, paracrawl_size=FLAGS.paracrawl_size, is_scores_path=FLAGS.is_scores_path, num_to_keep=int(gradual_selection_size), pseudo_path=FLAGS.pseudo_path, repeat_count=FLAGS.repeat_count, newscommentary_size=FLAGS.newscommentary_size, split_tokenizer=FLAGS.split_tokenizer) train_iter = iter(train_ds) # Shard data to devices and do a training step. if not FLAGS.eval_only: logging.info('Doing Training.') with jax.profiler.StepTraceAnnotation('train', step_num=step): try: batch = common_utils.shard( jax.tree_map(np.asarray, next(train_iter))) optimizer, metrics = p_train_step( optimizer, batch, dropout_rng=dropout_rngs) train_metrics.append(metrics) except StopIteration: is_last_step = True # Quick indication that training is happening. logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step) for h in hooks: h(step) # Periodic metric handling. 
if (step - start_step) % dynamic_eval_freq == 0 or is_last_step: if not FLAGS.eval_only: with report_progress.timed('training_metrics'): logging.info('Gathering training metrics.') train_metrics = common_utils.get_metrics(train_metrics) lr = train_metrics.pop('learning_rate').mean() metrics_sums = jax.tree_map(jnp.sum, train_metrics) denominator = metrics_sums.pop('denominator') summary = jax.tree_map(lambda x: x / denominator, metrics_sums) # pylint: disable=cell-var-from-loop summary['learning_rate'] = lr summary = {'train_' + k: v for k, v in summary.items()} writer.write_scalars(step, summary) train_metrics = [] if FLAGS.eval_only: p_eval_per_pos_step = jax.pmap(functools.partial( train_util.eval_per_pos_step, config=eval_config), axis_name='batch') # Get per example loss loss_filename = FLAGS.model_dir + '/test_losses.csv' train_util.write_per_example_losses( p_eval_step=p_eval_per_pos_step, target=optimizer.target, eval_ds=eval_ds, num_eval_steps=FLAGS.num_eval_steps, loss_filename=loss_filename) else: with report_progress.timed('eval'): eval_results = train_util.evaluate( p_eval_step=p_eval_step, target=optimizer.target, eval_ds=eval_ds, num_eval_steps=FLAGS.num_eval_steps) curr_eval_loss = eval_results['loss'] eval_loss_history.append(curr_eval_loss) if len(eval_loss_history) > 1: improvement_rate = 0.000004 orig_loss = eval_loss_history[-2] true_improvement = orig_loss - curr_eval_loss expected_improvement = ( step - last_eval_step) * improvement_rate # percent_change = (orig_loss - curr_eval_loss) / orig_loss # percent_change *= 100 if true_improvement < expected_improvement: # percent_change<.1: do_resample_data = True last_eval_step = step writer.write_scalars( step, {'eval_' + k: v for k, v in eval_results.items()}) if FLAGS.aux_eval_dataset: for aux_i, aux_eval_ds in enumerate(aux_datasets): with report_progress.timed('aux_eval'): eval_results = train_util.evaluate( p_eval_step=p_eval_step, target=optimizer.target, eval_ds=aux_eval_ds, num_eval_steps=FLAGS.num_eval_steps) writer.write_scalars( step, { 'aux' + str(aux_i) + '_eval_' + k: v for k, v in eval_results.items() }) if FLAGS.compute_bleu: with report_progress.timed('translate_and_bleu'): decode_file = FLAGS.model_dir + '/decodes.csv' exemplars, bleu_score = train_util.translate_and_calculate_bleu( p_pred_step=p_pred_step, p_init_cache=p_init_cache, target=optimizer.target, predict_ds=predict_ds, decode_tokens=decode_tokens, max_predict_length=FLAGS.max_predict_length, num_eval_steps=FLAGS.num_eval_steps, decode_file=decode_file if FLAGS.eval_only else '') writer.write_scalars(step, {'bleu': bleu_score}) writer.write_texts(step, {'samples': exemplars}) # Save a checkpoint on one host after every checkpoint_freq steps. save_checkpoint = ((step - start_step) % FLAGS.checkpoint_freq == 0 or is_last_step) if FLAGS.save_checkpoints and save_checkpoint and jax.process_index( ) == 0: if curr_eval_loss < best_eval_loss: # only save better checkpoints best_eval_loss = curr_eval_loss with report_progress.timed('checkpoint'): checkpoints.save_checkpoint( FLAGS.model_dir, jax_utils.unreplicate(optimizer), step, keep=FLAGS.chkpts_to_keep, overwrite=True) if is_last_step: break
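# Illustrative sketch of the data-reselection trigger used above: after each
# eval, the actual loss improvement since the previous eval is compared with
# an expected improvement proportional to the number of steps elapsed; if it
# falls short, the selected training subset shrinks by 30% and the data
# pipeline is rebuilt. `should_resample` is an assumed helper name.
def should_resample(eval_loss_history, steps_since_last_eval,
                    improvement_rate=4e-6):
  if len(eval_loss_history) < 2:
    return False
  true_improvement = eval_loss_history[-2] - eval_loss_history[-1]
  expected_improvement = steps_since_last_eval * improvement_rate
  return true_improvement < expected_improvement


# Example: the loss barely moved over 2000 steps (0.001 < 0.008), so the
# selection size would be multiplied by 0.7 and the datasets reloaded.
example_history = [2.310, 2.309]
assert should_resample(example_history, steps_since_last_eval=2000)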
def train_and_evaluate(config, workdir, strategy): """Runs a training and evaluation loop. Args: config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains checkpoint training will be resumed from the latest checkpoint. strategy: Distribution strategy to use for distributing the model. """ tf.io.gfile.makedirs(workdir) tf_rng, data_rng = tf.random.experimental.stateless_split((config.seed, 0), 2) tf.random.set_seed(tf_rng.numpy()[0]) # Input pipeline. ds_info, train_ds, val_ds, test_ds = input_pipeline.create_datasets( config, data_rng, strategy=strategy) train_iter = iter(train_ds) # pytype: disable=wrong-arg-types # Learning rate schedule. num_train_steps = config.num_train_steps if num_train_steps == -1: num_train_steps = (ds_info.splits["train"].num_examples // config.global_batch_size * config.num_epochs) steps_per_epoch = num_train_steps // config.num_epochs logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps, steps_per_epoch) # We treat the learning rate in the config as the learning rate for batch size # 256 but scale it according to our batch size. base_learning_rate = config.learning_rate * config.global_batch_size / 256.0 # Initialize model. num_classes = ds_info.features["label"].num_classes if config.distill_teacher: do_distill = True teacher_file_list = (config.distill_teacher).split(",") teacher_models = load_teacher_models(teacher_file_list, num_classes, config, strategy) distill_params = {} distill_params["alpha"] = config.distill_alpha distill_params["beta"] = config.distill_fd_beta distill_params["teacher_model"] = TeacherModel(teacher_models, name="teacher") else: do_distill = False distill_params = None state = create_state(config, num_classes=num_classes, strategy=strategy) ckpt_manager = tf.train.CheckpointManager(checkpoint=state, directory=workdir, max_to_keep=5) if ckpt_manager.latest_checkpoint: state.restore(ckpt_manager.latest_checkpoint) logging.info("Restored from %s", ckpt_manager.latest_checkpoint) else: logging.info("Initializing from scratch.") initial_step = state.global_step.numpy().item() learning_rate_fn = functools.partial(get_learning_rate, base_learning_rate=base_learning_rate, steps_per_epoch=steps_per_epoch, num_epochs=config.num_epochs, warmup_epochs=config.warmup_epochs) writer = metric_writers.create_default_writer(workdir) writer.write_hparams(dict(config)) logging.info("Starting training loop at step %d.", initial_step) report_progress = periodic_actions.ReportProgress( num_train_steps=num_train_steps, writer=writer) with metric_writers.ensure_flushes(writer): for step in range(initial_step, num_train_steps + 1): state.model.trainable = True # `step` is a Python integer. `global_step` is a TF variable on the # GPU/TPU devices. is_last_step = step == num_train_steps train_step(state, train_iter, config.weight_decay, learning_rate_fn, do_distill, distill_params, strategy) state.train_metrics.update_state_lr( learning_rate_fn(state.global_step.numpy().item())) # Quick indication that training is happening. 
logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step) report_progress(step) if step == initial_step: parameter_overview.log_parameter_overview(state.model) if step % config.log_loss_every_steps == 0 or is_last_step: writer.write_scalars(step, state.train_metrics.result()) state.train_metrics.reset_states() state.train_metrics.reset_lr() if step % config.eval_every_steps == 0 or is_last_step: state.model.trainable = False if config.dataset == "imagenet-lt": evaluate(state, val_ds, state.val_metrics, strategy) writer.write_scalars(step, state.val_metrics.result()) logging.info("Num val images %d", state.val_metrics.accuracy.count.numpy()) evaluate(state, test_ds, state.test_metrics, strategy) writer.write_scalars(step, state.test_metrics.result()) logging.info("Num test images %d", state.test_metrics.accuracy.count.numpy()) if step % config.checkpoint_every_steps == 0 or is_last_step: checkpoint_path = ckpt_manager.save(step) logging.info("Saved checkpoint %s", checkpoint_path) logging.info("Finishing training at step %d", step) logging.info("Saving the final weights") file_path = "%s/final_weights" % workdir state.model.save_weights(file_path, save_format="tf")
def monitor_and_sample(config, work_dir): """Monitors `work_dir` for new checkpoints and run sampling on them. Args: config: Hyperparameter configuration for training and evaluation. work_dir: Directory where the tensorboard summaries are written to. """ # Init rng key. rng = jax.random.PRNGKey(config.seed) data_rng, rng = jax.random.split(rng) is_first_host = jax.process_index() == 0 # TODO(agritsenko): We are loading the datasets just to get the metadata. # Can we be smarter about this? if config.dataset.name.endswith('speech_commands09'): _, ds_metadata = input_pipeline_sc09.get_dataset(data_rng, config) else: raise ValueError(f'Unknown dataset {config.dataset.name}.') # TODO(agritsenko): Can we fix the ugly nested dicts? config.data_shape = ds_metadata['train']['shape']['inputs'][2:] config.num_classes = ds_metadata['train']['num_classes'] config.sample_rate = ds_metadata['train']['sample_rate'] writer = metric_writers.create_default_writer( work_dir, just_logging=jax.process_index() > 0) rng, init_rng = jax.random.split(rng) model, variables = model_setup(init_rng, config) # From now on we want different rng across hosts: rng = jax.random.fold_in(rng, jax.process_index()) rng, rng_sample = jax.random.split(rng) def tx_fn(lr): return optax.adamw(lr, b1=0.9, b2=config.beta2, eps=1e-08, eps_root=0.0, weight_decay=config.weight_decay) state = language_train_state.TrainState.create(params=variables['params'], tx_fn=tx_fn) # Wait for checkpoints in an loop. ckpt_path_iterator = checkpoint.checkpoints_iterator(work_dir, target=None) with metric_writers.ensure_flushes(writer): for _ in ckpt_path_iterator: state, step = checkpoint.restore_from_path(work_dir, state) is_last_step = step == config.num_train_steps - 1 logging.info('Loaded checkpoint for step: %d', step) # Replicate the state state = flax.jax_utils.replicate(state) ######################### Run sampling ############################### chain = model.sample(jax.random.fold_in(rng_sample, step), state.ema_params, config.sample_batch_size, chain_out_size=config.get( 'chain_out_size', model.num_stages)) if is_first_host: chain = jax.device_get(chain) long_sample = np.reshape(chain[-1], (1, -1, 1)).astype(np.float32) long_sample = (2. * long_sample) / config.num_classes - 1. long_sample = long_sample.astype(np.float32) writer.write_audios(step, {'samples': long_sample}, sample_rate=config.sample_rate) if is_last_step: break
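# Illustrative sketch of the sample post-processing in `monitor_and_sample`:
# the last slice of the sampling chain is flattened into one long waveform and
# the integer class indices are rescaled from {0..num_classes-1} to [-1, 1)
# before being written as audio. `chain_to_waveform` and the shapes below are
# assumptions for illustration.
import numpy as np


def chain_to_waveform(chain, num_classes):
  last = np.asarray(chain[-1])                      # [batch, time, 1] class ids
  waveform = np.reshape(last, (1, -1, 1)).astype(np.float32)
  return (2.0 * waveform) / num_classes - 1.0       # -> [-1, 1)


toy_chain = np.random.randint(0, 256, size=(4, 2, 16000, 1))  # 4 chain slices
toy_audio = chain_to_waveform(toy_chain, num_classes=256)
assert toy_audio.shape == (1, 32000, 1)
assert toy_audio.min() >= -1.0 and toy_audio.max() < 1.0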
def predict_and_evaluate(config, workdir, ckpt_path=None): """Runs a testing and evaluation loop. Args: config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains checkpoint training will be resumed from the latest checkpoint. ckpt_path: The checkpoint to evaluate. If not specified, use the latest checkpoint. """ logging.info('Starting testing at %s', workdir) tf.io.gfile.makedirs(workdir) rng = jax.random.PRNGKey(config.seed) # Build input pipeline. rng, data_rng = jax.random.split(rng) data_rng = jax.random.fold_in(data_rng, jax.process_index()) test_ds = [] for split in config.dataset.test_splits: ds = input_pipeline.create_val_dataset( config.dataset, split, config.dataset.test_per_device_batch_size, config.dataset.test_pad_last_batch) test_ds.append(ds) # Initialize model. inputs = train_utils.get_init_inputs(test_ds[0]) rng, model_rng = jax.random.split(rng) predict_config = models.TransformerConfig(**config.model.to_dict()) predict_config = predict_config.replace(decode=True) model = models.Model(predict_config) state = train_utils.create_train_state(model, config, model_rng, inputs=inputs) writer = metric_writers.create_default_writer( workdir, just_logging=jax.process_index() > 0) # Set up checkpointing of the model and the input pipeline. checkpoint_dir = os.path.join(workdir, 'checkpoints') ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, max_to_keep=3) logging.info('Testing and evaluating checkpoint %s', ckpt_path) try: state = ckpt.restore(state, ckpt_path) except FileNotFoundError: state = ckpt.restore_or_initialize(state) step = int(state.step) p_pred_step = jax.pmap(functools.partial(predict_step, config=predict_config), axis_name='batch', static_broadcasted_argnums=(3, )) p_init_cache = jax.pmap(functools.partial(init_cache, config=predict_config), axis_name='batch') # Distribute testing. state = flax_utils.replicate(state) with metric_writers.ensure_flushes(writer): test_metrics = {} for ds, split in zip(test_ds, config.dataset.test_splits): ds_metrics = evaluate_sequence_accuracy(p_pred_step, p_init_cache, state, ds, config, split, workdir, config.num_test_steps) ds_metrics = {f'{k}_{split}': v for k, v in ds_metrics.items()} test_metrics.update(ds_metrics) writer.write_scalars(step, test_metrics)
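# Illustrative sketch of the per-split metric namespacing at the end of
# `predict_and_evaluate`: each test split's metrics are suffixed with the
# split name and merged into a single dict before `writer.write_scalars`.
# `merge_split_metrics` is an assumed helper name.
def merge_split_metrics(per_split_metrics):
  merged = {}
  for split, split_metrics in per_split_metrics.items():
    merged.update({f'{k}_{split}': v for k, v in split_metrics.items()})
  return merged


# {'sequence_accuracy_valid': 0.91, 'sequence_accuracy_test': 0.88}
print(merge_split_metrics({
    'valid': {'sequence_accuracy': 0.91},
    'test': {'sequence_accuracy': 0.88},
}))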