def make_learner(
    self,
    random_key: networks_lib.PRNGKey,
    networks: d4pg_networks.D4PGNetworks,
    dataset: Iterator[reverb.ReplaySample],
    logger_fn: loggers.LoggerFactory,
    environment_spec: specs.EnvironmentSpec,
    replay_client: Optional[reverb.Client] = None,
    counter: Optional[counting.Counter] = None,
) -> core.Learner:
    del environment_spec, replay_client

    policy_optimizer = optax.adam(self._config.learning_rate)
    critic_optimizer = optax.adam(self._config.learning_rate)

    if self._config.clipping:
        policy_optimizer = optax.chain(
            optax.clip_by_global_norm(40.), policy_optimizer)
        critic_optimizer = optax.chain(
            optax.clip_by_global_norm(40.), critic_optimizer)

    # The learner updates the parameters (and initializes them).
    return learning.D4PGLearner(
        policy_network=networks.policy_network,
        critic_network=networks.critic_network,
        random_key=random_key,
        policy_optimizer=policy_optimizer,
        critic_optimizer=critic_optimizer,
        clipping=self._config.clipping,
        discount=self._config.discount,
        target_update_period=self._config.target_update_period,
        iterator=dataset,
        counter=counter,
        logger=logger_fn('learner'),
        num_sgd_steps_per_step=self._config.num_sgd_steps_per_step)

def make_learner(
    self,
    random_key: networks_lib.PRNGKey,
    networks: impala_networks.IMPALANetworks,
    dataset: Iterator[reverb.ReplaySample],
    logger_fn: loggers.LoggerFactory,
    environment_spec: specs.EnvironmentSpec,
    replay_client: Optional[reverb.Client] = None,
    counter: Optional[counting.Counter] = None,
) -> core.Learner:
    del environment_spec, replay_client

    optimizer = optax.chain(
        optax.clip_by_global_norm(self._config.max_gradient_norm),
        optax.adam(
            self._config.learning_rate,
            b1=self._config.adam_momentum_decay,
            b2=self._config.adam_variance_decay,
            eps=self._config.adam_eps,
            eps_root=self._config.adam_eps_root))

    return learning.IMPALALearner(
        networks=networks,
        iterator=dataset,
        optimizer=optimizer,
        random_key=random_key,
        discount=self._config.discount,
        entropy_cost=self._config.entropy_cost,
        baseline_cost=self._config.baseline_cost,
        max_abs_reward=self._config.max_abs_reward,
        counter=counter,
        logger=logger_fn('learner'),
    )

def main():
    # Create the dataset.
    train_dataset, vocab_size = load(batch_size, sequence_length)

    # Set up the model, loss, and updater.
    forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
                                  num_layers, dropout_rate)
    forward_fn = hk.transform(forward_fn)
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

    optimizer = optax.chain(
        optax.clip_by_global_norm(grad_clip_value),
        optax.adam(learning_rate, b1=0.9, b2=0.99))

    updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)

    # Initialize parameters.
    logging.info('Initializing parameters...')
    rng = jax.random.PRNGKey(428)
    data = next(train_dataset)
    state = updater.init(rng, data)

    logging.info('Starting train loop...')
    prev_time = time.time()
    for step in range(MAX_STEPS):
        data = next(train_dataset)
        state, metrics = updater.update(state, data)

def test_batch_overfit(train_dataset):
    vocab_size, d_model, num_heads, num_layers = 100, 32, 8, 1
    dropout_rate, grad_clip_value, learning_rate = 0.01, 0.25, 2e-2
    max_iter = 100

    # Set up the model, loss, and updater.
    forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
                                  num_layers, dropout_rate, max_iter)
    forward_fn = hk.transform(forward_fn)
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

    optimizer = optax.chain(
        optax.clip_by_global_norm(grad_clip_value),
        optax.adam(learning_rate, b1=0.9, b2=0.99))

    updater = Updater(forward_fn.init, loss_fn, optimizer)

    # Initialize parameters.
    rng = jax.random.PRNGKey(428)
    data = next(train_dataset)
    state = updater.init(rng, data)

    for step in range(100):
        data = next(train_dataset)
        state, metrics = updater.update(state, data)

    assert metrics['loss'] < 0.1
    assert metrics['step'] == 99

def main(_):
    # Create the dataset.
    train_dataset, vocab_size = dataset.load(FLAGS.batch_size,
                                             FLAGS.sequence_length)

    # Set up the model, loss, and updater.
    forward_fn = build_forward_fn(vocab_size, FLAGS.d_model, FLAGS.num_heads,
                                  FLAGS.num_layers, FLAGS.dropout_rate)
    forward_fn = hk.transform(forward_fn)
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

    optimizer = optax.chain(
        optax.clip_by_global_norm(FLAGS.grad_clip_value),
        optax.adam(FLAGS.learning_rate, b1=0.9, b2=0.99))

    updater = Updater(forward_fn.init, loss_fn, optimizer)
    updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir)

    # Initialize parameters.
    logging.info('Initializing parameters...')
    rng = jax.random.PRNGKey(428)
    data = next(train_dataset)
    state = updater.init(rng, data)

    logging.info('Starting train loop...')
    prev_time = time.time()
    for step in range(MAX_STEPS):
        data = next(train_dataset)
        state, metrics = updater.update(state, data)

        # We use JAX runahead to mask data preprocessing and JAX dispatch
        # overheads. Using values from state/metrics too often will block the
        # runahead and can cause these overheads to become more prominent.
        if step % LOG_EVERY == 0:
            steps_per_sec = LOG_EVERY / (time.time() - prev_time)
            prev_time = time.time()
            metrics.update({'steps_per_sec': steps_per_sec})
            logging.info({k: float(v) for k, v in metrics.items()})

def make_learner(
    self,
    random_key: networks_lib.PRNGKey,
    networks: td3_networks.TD3Networks,
    dataset: Iterator[reverb.ReplaySample],
    replay_client: Optional[reverb.Client] = None,
    counter: Optional[counting.Counter] = None,
) -> core.Learner:
    critic_optimizer = optax.adam(self._config.critic_learning_rate)
    twin_critic_optimizer = optax.adam(self._config.critic_learning_rate)
    policy_optimizer = optax.adam(self._config.policy_learning_rate)

    if self._config.policy_gradient_clipping is not None:
        policy_optimizer = optax.chain(
            optax.clip_by_global_norm(self._config.policy_gradient_clipping),
            policy_optimizer)

    return learning.TD3Learner(
        networks=networks,
        random_key=random_key,
        discount=self._config.discount,
        target_sigma=self._config.target_sigma,
        noise_clip=self._config.noise_clip,
        policy_optimizer=policy_optimizer,
        critic_optimizer=critic_optimizer,
        twin_critic_optimizer=twin_critic_optimizer,
        num_sgd_steps_per_step=self._config.num_sgd_steps_per_step,
        bc_alpha=self._config.bc_alpha,
        iterator=dataset,
        logger=self._logger_fn(),
        counter=counter)

def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: networks_lib.FeedForwardNetwork,
    config: DQNConfig,
):
    """Initialize the agent."""
    # Data is communicated via reverb replay.
    reverb_replay = replay.make_reverb_prioritized_nstep_replay(
        environment_spec=environment_spec,
        n_step=config.n_step,
        batch_size=config.batch_size,
        max_replay_size=config.max_replay_size,
        min_replay_size=config.min_replay_size,
        priority_exponent=config.priority_exponent,
        discount=config.discount,
    )
    self._server = reverb_replay.server

    optimizer = optax.chain(
        optax.clip_by_global_norm(config.max_gradient_norm),
        optax.adam(config.learning_rate),
    )
    key_learner, key_actor = jax.random.split(jax.random.PRNGKey(config.seed))

    # The learner updates the parameters (and initializes them).
    learner = learning.DQNLearner(
        network=network,
        obs_spec=environment_spec.observations,
        random_key=key_learner,
        optimizer=optimizer,
        discount=config.discount,
        importance_sampling_exponent=config.importance_sampling_exponent,
        target_update_period=config.target_update_period,
        iterator=reverb_replay.data_iterator,
        replay_client=reverb_replay.client,
    )

    # The actor selects actions according to the policy.
    def policy(params: networks_lib.Params, key: jnp.ndarray,
               observation: jnp.ndarray) -> jnp.ndarray:
        action_values = network.apply(params, observation)
        return rlax.epsilon_greedy(config.epsilon).sample(key, action_values)

    actor = actors.FeedForwardActor(
        policy=policy,
        rng=hk.PRNGSequence(key_actor),
        variable_client=variable_utils.VariableClient(learner, ''),
        adder=reverb_replay.adder)

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(config.batch_size, config.min_replay_size),
        observations_per_step=config.batch_size / config.samples_per_insert,
    )

def train(*, data_folder, batch_size, epochs, learning_rate, weight_decay,
          seed, max_norm, text_vocab, text_dim, text_depth, text_heads,
          audio_dim, audio_depth, audio_heads):
    # rng
    rng_key = random.PRNGKey(seed)

    # data
    dataset = PairTextSpectrogramDataset(data_folder)
    dl = DataLoader(dataset, batch_size=batch_size,
                    collate_fn=pair_text_spectrogram_dataset_collate_fn,
                    drop_last=True, shuffle=True)

    # model
    model = CLAP(text_vocab=text_vocab, text_dim=text_dim,
                 text_depth=text_depth, text_heads=text_heads,
                 audio_dim=audio_dim, audio_depth=audio_depth,
                 audio_heads=audio_heads)

    # optimizer
    exclude_bias = lambda params: tree_util.tree_map(lambda x: x.ndim != 1,
                                                     params)
    optim = chain(clip_by_global_norm(max_norm),
                  scale_by_adam(eps=1e-4),
                  add_decayed_weights(weight_decay, exclude_bias),
                  scale(-learning_rate))

    # init
    audio, audio_mask, text, text_mask = next(iter(dl))
    params = model.init(rng_key, text, audio, text_mask, audio_mask)
    optim_state = optim.init(params)

    # loss function, for use with value_and_grad
    @jit
    @value_and_grad
    def loss_fn(params, text, audio, text_mask, audio_mask):
        return model.apply(params, text, audio, text_mask, audio_mask)

    # train loop
    for _ in range(epochs):
        for audio, audio_mask, text, text_mask in dl:
            loss, grads = loss_fn(params, text, audio, text_mask, audio_mask)
            updates, optim_state = optim.update(grads, optim_state, params)
            params = apply_updates(params, updates)
            print(f'loss: {loss}')

def make_learner(
    self,
    random_key: networks_lib.PRNGKey,
    networks: ppo_networks.PPONetworks,
    dataset: Iterator[reverb.ReplaySample],
    logger_fn: loggers.LoggerFactory,
    environment_spec: specs.EnvironmentSpec,
    replay_client: Optional[reverb.Client] = None,
    counter: Optional[counting.Counter] = None,
) -> core.Learner:
    del environment_spec, replay_client

    if callable(self._config.learning_rate):
        optimizer = optax.chain(
            optax.clip_by_global_norm(self._config.max_gradient_norm),
            optax.scale_by_adam(eps=self._config.adam_epsilon),
            optax.scale_by_schedule(self._config.learning_rate),
            optax.scale(-1))
    else:
        optimizer = optax.chain(
            optax.clip_by_global_norm(self._config.max_gradient_norm),
            optax.scale_by_adam(eps=self._config.adam_epsilon),
            optax.scale(-self._config.learning_rate))

    return learning.PPOLearner(
        ppo_networks=networks,
        iterator=dataset,
        discount=self._config.discount,
        entropy_cost=self._config.entropy_cost,
        value_cost=self._config.value_cost,
        max_abs_reward=self._config.max_abs_reward,
        ppo_clipping_epsilon=self._config.ppo_clipping_epsilon,
        clip_value=self._config.clip_value,
        gae_lambda=self._config.gae_lambda,
        counter=counter,
        random_key=random_key,
        optimizer=optimizer,
        num_epochs=self._config.num_epochs,
        num_minibatches=self._config.num_minibatches,
        logger=logger_fn('learner'),
    )

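# A minimal sketch (separate from the PPO builder above, values made up) of why
# the two branches are written this way: for a constant learning rate, the
# manual chain clip -> scale_by_adam -> scale(-lr) builds the same update as
# clip -> optax.adam(lr); the first branch exists only because a callable
# learning rate has to go through scale_by_schedule.
import jax.numpy as jnp
import optax

grads = {'w': jnp.array([3.0, -4.0])}
params = {'w': jnp.zeros(2)}

manual = optax.chain(
    optax.clip_by_global_norm(0.5),
    optax.scale_by_adam(eps=1e-5),
    optax.scale(-1e-3))
builtin = optax.chain(
    optax.clip_by_global_norm(0.5),
    optax.adam(1e-3, eps=1e-5))

manual_updates, _ = manual.update(grads, manual.init(params))
builtin_updates, _ = builtin.update(grads, builtin.init(params))
# Both chains produce the same update for a constant learning rate.
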
def optimizer(self, learning_rate):
    """Construct optimizer."""
    clip = optax.clip_by_global_norm(self.config.optimizer.gradient_clip_norm)
    optimizer = getattr(optax, self.config.optimizer.name)(
        learning_rate,
        **self.config.optimizer.args,
    )
    optim_step = optax.chain(clip, optimizer)
    optim_step = optimizers.maybe_skip_gradient_update(
        optim_step,
        self.config.optimizer.gradient_skip_norm,
    )
    return optim_step

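# A small companion sketch: `optimizers.maybe_skip_gradient_update` above is a
# project-local helper that skips steps whose gradient norm is too large.
# Stock optax offers a related guard, optax.apply_if_finite, which skips
# updates containing NaN/Inf; shown here with made-up hyperparameters.
import optax

base = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3))
# Skip non-finite updates, giving up after 5 consecutive bad steps.
guarded = optax.apply_if_finite(base, max_consecutive_errors=5)
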
def create_train_state(rng, model, img_size, lr_schedule_fn, weight_decay,
                       max_norm):
    tx = optax.chain(optax.clip_by_global_norm(max_norm),
                     optax.scale_by_adam(),
                     optax.additive_weight_decay(weight_decay),
                     optax.scale_by_schedule(lr_schedule_fn))
    params = model.init(rng, jax.numpy.ones((1, img_size, img_size, 3)),
                        is_training=False)
    train_state = TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=tx,
    )
    return train_state

def make_optimizer(optimizer_config, lr_schedule):
    """Construct the optax optimizer with given LR schedule."""
    if (optimizer_config.get('decay_pos_embs') is None or
            optimizer_config.decay_pos_embs):
        # Decay learned position embeddings by default.
        weight_decay_exclude_names = ['b']
    else:
        weight_decay_exclude_names = ['pos_embs', 'b']

    optax_chain = []
    if optimizer_config.max_norm > 0:
        optax_chain.append(
            optax.clip_by_global_norm(optimizer_config.max_norm))

    if optimizer_config.optimizer == 'adam':
        # See: https://arxiv.org/abs/1412.6980
        optax_chain.extend([
            optax.scale_by_adam(**optimizer_config.adam_kwargs),
            add_weight_decay(
                optimizer_config.weight_decay,
                exclude_names=weight_decay_exclude_names)
        ])
    elif optimizer_config.optimizer == 'lamb':
        # See: https://arxiv.org/abs/1904.00962
        optax_chain.extend([
            optax.scale_by_adam(**optimizer_config.lamb_kwargs),
            add_weight_decay(
                optimizer_config.weight_decay,
                exclude_names=weight_decay_exclude_names),
            optax.scale_by_trust_ratio()
        ])
    else:
        raise ValueError(f'Undefined optimizer {optimizer_config.optimizer}')

    # Scale by the (negative) learning rate.
    optax_chain.extend([
        optax.scale_by_schedule(lr_schedule),
        optax.scale(-1),
    ])

    return optax.chain(*optax_chain)

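# `add_weight_decay` above is a helper defined in the surrounding codebase, not
# an optax primitive. A rough sketch of the same exclusion-by-name idea with
# stock optax, assuming Haiku-style parameters where biases are named 'b':
import haiku as hk
import optax

def make_decay_mask(exclude_names):
    # True marks parameters that should be decayed.
    def mask_fn(params):
        return hk.data_structures.map(
            lambda module_name, name, value: name not in exclude_names, params)
    return mask_fn

decay = optax.add_decayed_weights(
    1e-4, mask=make_decay_mask(['pos_embs', 'b']))
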
def main():
    parser = ArgumentParser()
    parser.add_argument('-b', '--batch-size', default=32, type=int)
    parser.add_argument('-d', '--rnn-hidden-size', default=256, type=int)
    parser.add_argument('-f', '--data-file', default='/tmp/convex_hull.dat',
                        type=str)
    parser.add_argument('-l', '--lr', default=1e-3, type=float)
    parser.add_argument('-r', '--resume-training', default=False,
                        action='store_true')
    parser.add_argument('-t', '--training-steps', default=100_000, type=int)
    parser.add_argument('-w', '--use-wandb', default=False,
                        action='store_true')
    parser.add_argument('-wd', '--wd', default=1e-2, type=float)
    hparams = parser.parse_args()

    if hparams.use_wandb:
        wandb.init(project='pointer-networks', dir='/tmp')
        wandb.config.update(hparams)
    print(hparams)

    dataloader = ConvexHullDataLoader(data_filepath=hparams.data_file)
    loss_fn = partial(_loss_fn, hparams=hparams)
    optimizer = optax.chain(optax.adamw(hparams.lr, weight_decay=hparams.wd),
                            optax.clip_by_global_norm(10.))

    train_iter = dataloader.data_iter(hparams.batch_size, 'train')
    val_iter = dataloader.data_iter(hparams.batch_size, 'val')
    wandb_obj = wandb if hparams.use_wandb else None

    trainer = Trainer(
        train_loss_fn=loss_fn,
        train_data_iter=train_iter,
        val_loss_fn=loss_fn,
        val_data_iter=val_iter,
        optimizer=optimizer,
        wandb=wandb_obj,
        resume=hparams.resume_training
    )
    plot_att_fn = partial(_plot_attention, hparams=hparams, wandb=wandb_obj)
    trainer.register_callback(1000, plot_att_fn)
    trainer.fit(total_steps=hparams.training_steps)

def test_graph_network_learning(self, spatial_dimension, dtype):
    key = random.PRNGKey(0)
    R_key, dr0_key, params_key = random.split(key, 3)

    d, _ = space.free()

    R = random.uniform(R_key, (6, 3, spatial_dimension), dtype=dtype)
    dr0 = random.uniform(dr0_key, (6, 3, 3), dtype=dtype)
    E_gt = vmap(
        lambda R, dr0: np.sum(
            (space.distance(space.map_product(d)(R, R)) - dr0) ** 2))

    cutoff = 0.2
    init_fn, energy_fn = energy.graph_network(d, cutoff)
    params = init_fn(params_key, R[0])

    @jit
    def loss(params, R):
        return np.mean((vmap(energy_fn, (None, 0))(params, R) -
                        E_gt(R, dr0)) ** 2)

    opt = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-4))

    @jit
    def update(params, opt_state, R):
        updates, opt_state = opt.update(grad(loss)(params, R), opt_state)
        return optax.apply_updates(params, updates), opt_state

    opt_state = opt.init(params)

    l0 = loss(params, R)
    for i in range(4):
        params, opt_state = update(params, opt_state, R)

    assert loss(params, R) < l0 * 0.95

def configure_update_step(learning_rate: float, loss: Callable):
    """Configure an optax training update step."""
    opt = optax.chain(optax.clip_by_global_norm(1.0),
                      optax.adam(learning_rate))

    @jit
    def _update_step(params, opt_state, positions, labels):
        updates, opt_state = opt.update(
            grad(loss)(params, positions, labels), opt_state)
        return optax.apply_updates(params, updates), opt_state

    @jit
    def update_step(params_and_opt_state, batches):
        def inner_update(params_and_opt_state, batch):
            params, opt_state = params_and_opt_state
            b_xs, b_labels = batch
            return _update_step(params, opt_state, b_xs, b_labels), 0
        return lax.scan(inner_update, params_and_opt_state, batches)[0]

    return update_step, opt

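# Usage sketch for the factory above, with toy shapes and a made-up loss. The
# returned `update_step` folds a stack of minibatches into the parameters via
# lax.scan, so each batch array needs a leading "number of minibatches" axis.
import jax.numpy as jnp

def toy_loss(params, xs, labels):
    preds = xs @ params['w']
    return jnp.mean((preds - labels) ** 2)

update_step, opt = configure_update_step(1e-3, toy_loss)
params = {'w': jnp.zeros((4,))}
opt_state = opt.init(params)
xs = jnp.ones((8, 32, 4))       # 8 minibatches of 32 examples
labels = jnp.zeros((8, 32))
params, opt_state = update_step((params, opt_state), (xs, labels))
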
def create_train_state(config, rng, learning_rate_fn, example_batch):
    """Create and initialize the model.

    Args:
      config: Configuration for model.
      rng: JAX PRNG Key.
      learning_rate_fn: learning rate function
      example_batch: for model initialization

    Returns:
      The initialized TrainState with the optimizer.
    """
    model, variables = create_model(config, rng, example_batch)
    params = variables['params']
    parameter_overview.log_parameter_overview(params)

    optimizer = optax.adamw(
        learning_rate=learning_rate_fn,
        b1=0.9,
        b2=.98,
        eps=1e-9,
        weight_decay=config.train.weight_decay)

    if config.train.grad_max_norm > 0:
        tx = optax.chain(optax.clip_by_global_norm(config.train.grad_max_norm),
                         optimizer)
    elif config.train.grad_max_val > 1:
        tx = optax.chain(optax.clip(config.train.grad_max_val), optimizer)
    else:
        tx = optimizer

    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=tx,
    )
    return model, state

def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    forward_fn: impala_types.PolicyValueFn,
    unroll_init_fn: impala_types.PolicyValueInitFn,
    unroll_fn: impala_types.PolicyValueFn,
    initial_state_init_fn: impala_types.RecurrentStateInitFn,
    initial_state_fn: impala_types.RecurrentStateFn,
    config: impala_config.IMPALAConfig,
    counter: Optional[counting.Counter] = None,
    logger: Optional[loggers.Logger] = None,
):
    networks = impala_networks.IMPALANetworks(
        forward_fn=forward_fn,
        unroll_init_fn=unroll_init_fn,
        unroll_fn=unroll_fn,
        initial_state_init_fn=initial_state_init_fn,
        initial_state_fn=initial_state_fn,
    )

    self._config = config

    # Data is handled by the reverb replay queue.
    num_actions = environment_spec.actions.num_values
    self._logger = logger or loggers.TerminalLogger('agent')

    key, key_initial_state = jax.random.split(
        jax.random.PRNGKey(self._config.seed))
    params = initial_state_init_fn(key_initial_state)
    extra_spec = {
        'core_state': initial_state_fn(params),
        'logits': np.ones(shape=(num_actions,), dtype=np.float32)
    }

    reverb_queue = replay.make_reverb_online_queue(
        environment_spec=environment_spec,
        extra_spec=extra_spec,
        max_queue_size=self._config.max_queue_size,
        sequence_length=self._config.sequence_length,
        sequence_period=self._config.sequence_period,
        batch_size=self._config.batch_size,
    )
    self._server = reverb_queue.server
    self._can_sample = reverb_queue.can_sample

    # Make the learner.
    optimizer = optax.chain(
        optax.clip_by_global_norm(self._config.max_gradient_norm),
        optax.adam(self._config.learning_rate),
    )
    key_learner, key_actor = jax.random.split(key)
    self._learner = learning.IMPALALearner(
        networks=networks,
        iterator=reverb_queue.data_iterator,
        random_key=key_learner,
        counter=counter,
        logger=logger,
        optimizer=optimizer,
        discount=self._config.discount,
        entropy_cost=self._config.entropy_cost,
        baseline_cost=self._config.baseline_cost,
        max_abs_reward=self._config.max_abs_reward,
    )

    # Make the actor.
    variable_client = variable_utils.VariableClient(self._learner,
                                                    key='policy')
    self._actor = acting.IMPALAActor(
        forward_fn=jax.jit(forward_fn, backend='cpu'),
        initial_state_init_fn=initial_state_init_fn,
        initial_state_fn=initial_state_fn,
        rng=hk.PRNGSequence(key_actor),
        adder=reverb_queue.adder,
        variable_client=variable_client,
    )

def train_and_evaluate(config, workdir, vocab_filepath):
    """Runs a training and evaluation loop.

    Args:
      config: Model and training configuration.
      workdir: Working directory for checkpoints and TensorBoard summaries. If
        this contains a checkpoint, training will be resumed from the latest
        checkpoint.
      vocab_filepath: Absolute path to SentencePiece vocab model.

    Raises:
      ValueError: If training or eval batch sizes won't fit number of hosts
        and devices, or config is underspecified.
    """
    # Update config before config validation.
    with config.unlocked():
        # Numeric floating point type to use for model computations.
        config.dtype = jnp.float32

    train_utils.validate_config(config)

    if jax.process_index() == 0:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "train"))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "eval"))
    else:
        train_summary_writer = None
        eval_summary_writer = None

    tokenizer = spm.SentencePieceProcessor()
    tokenizer.Load(vocab_filepath)
    tokenizer.SetEncodeExtraOptions("")
    # Note: [CLS] and [SEP] will be added by the data pipeline, not the
    # tokenizer.
    with config.unlocked():
        config.vocab_size = tokenizer.GetPieceSize()
        config.pad_id = tokenizer.pad_id()
    config = ml_collections.FrozenConfigDict(config)

    model = models.PreTrainingModel(config=config)
    rng = random.PRNGKey(config.seed)
    rng, init_rng = random.split(rng)
    params = _init_params(model, init_rng, config)

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        factors="constant * linear_warmup * linear_decay",
        base_learning_rate=config.learning_rate,
        warmup_steps=config.num_warmup_steps,
        decay_steps=config.num_train_steps - config.num_warmup_steps,
    )

    tx = optax.adamw(
        learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.01)
    if config.clipped_grad_norm:
        tx = optax.chain(
            optax.clip_by_global_norm(config.clipped_grad_norm), tx)

    # jit state creation to ensure arrays are created on same device as input
    # (i.e. CPU).
    state_cpu = jax.jit(
        functools.partial(
            FlaxTrainState.create, apply_fn=model.apply, params=params,
            tx=tx))()

    # We access model params only via state.params
    del params

    if config.num_experts > 1:
        sharded_match_fn = core_utils.match_fn(r".*expert.*")
        not_sharded_match_fn = lambda name: not sharded_match_fn(name)
    else:
        sharded_match_fn = None
        not_sharded_match_fn = lambda name: True

    state, start_step = _restore_state_from_checkpoint(workdir, state_cpu,
                                                       sharded_match_fn,
                                                       not_sharded_match_fn,
                                                       config)

    train_ds, eval_ds = _init_train_and_eval_ds(tokenizer, config)
    train_iter = iter(train_ds)

    # We init the first set of dropout PRNG keys, but update it afterwards
    # inside the main pmap'd training update for performance.
    rngs = random.split(rng, jax.local_device_count())

    loss_and_metrics_fn = functools.partial(
        _compute_loss_and_metrics,
        model=model,
        is_experts_model=config.num_experts > 1,
        auxiliary_loss_factor=config.auxiliary_loss_factor,
        router_z_loss_factor=config.router_z_loss_factor)
    train_step = functools.partial(
        train_utils.pmap_train_step,
        loss_and_metrics_fn=loss_and_metrics_fn,
        axis_name="batch",
        sharded_match_fn=sharded_match_fn,
        gradient_accum_steps=config.gradient_accum_steps)
    p_train_step = jax.pmap(train_step, axis_name="batch")

    eval_step = functools.partial(_compute_eval_stats, model=model)
    p_eval_step = jax.pmap(eval_step, axis_name="batch")

    seconds = 0.
    train_stats = []
    logging.info("Starting training loop.")
    logging.info("====================")
    for step in range(start_step, config.num_train_steps):
        with jax.profiler.StepTraceContext("train", step_num=step):
            train_batch = next(train_iter)
            train_batch = common_utils.shard(train_batch)

            tick = time.time()
            state, train_step_stats, rngs = p_train_step(
                state, train_batch, rng=rngs)
            if config.measure_step_speed:
                jax.tree_map(lambda opt: opt.block_until_ready(), state)
                tock = time.time()
                seconds += tock - tick

            train_stats.append(train_step_stats)

        if (step > 0 and config.save_checkpoints_steps and
                step % config.save_checkpoints_steps == 0):
            # We allow all hosts to potentially save checkpoints because some
            # model parameters are sharded across devices. Parameters
            # replicated across devices (i.e. not sharded) will only be
            # checkpointed by host 0.
            unreplicated_state = jax.tree_map(
                np.array,
                core_utils.tree_unreplicate_by_name(state,
                                                    not_sharded_match_fn))
            checkpoints.save_checkpoint(
                workdir,
                unreplicated_state,
                sharded_match_fn,
                step,
                keep=config.checkpoints_to_keep)
            del unreplicated_state  # Only used for checkpointing.

        # Periodic metric handling.
        if step % config.eval_frequency != 0 and step > 0:
            continue

        logging.info("Gathering training metrics at step: %d", step)
        train_metrics = train_utils.collect_metrics(train_stats)
        train_summary = train_utils.compute_pretraining_metrics(train_metrics)
        train_summary["learning_rate"] = learning_rate_fn(step)
        if config.measure_step_speed:
            train_summary["steps_per_sec"] = (step - start_step + 1) / seconds

        if jax.process_index() == 0:
            assert train_summary_writer
            for key, val in train_summary.items():
                train_summary_writer.scalar(key, val, step)
            train_summary_writer.flush()

        # Reset metric accumulation for next training evaluation cycle.
        train_stats = []

        logging.info("Gathering evaluation metrics at step: %d", step)
        eval_stats = []
        for _, eval_batch in zip(range(config.max_num_eval_steps), eval_ds):
            eval_batch = common_utils.shard(eval_batch)
            eval_stats.append(p_eval_step(state.params, eval_batch))
        eval_metrics = train_utils.collect_metrics(eval_stats)
        eval_summary = train_utils.compute_pretraining_metrics(
            eval_metrics, record_grad_norm=False)

        if jax.process_index() == 0:
            assert eval_summary_writer
            for key, val in eval_summary.items():
                eval_summary_writer.scalar(key, val, step)
            eval_summary_writer.flush()

def create_optimizer(config):
    """Creates the optimizer associated to a config."""
    ops = []

    # Gradient clipping either by norm `gradient_norm_clip` or by absolute
    # value `gradient_value_clip`.
    if "gradient_clip" in config:
        raise ValueError("'gradient_clip' is deprecated, please use "
                         "'gradient_norm_clip'.")
    assert not ("gradient_norm_clip" in config and
                "gradient_value_clip" in config), (
                    "Gradient clipping by norm and by value are exclusive.")
    if "gradient_norm_clip" in config:
        ops.append(optax.clip_by_global_norm(config.gradient_norm_clip))
    if "gradient_value_clip" in config:
        ops.append(optax.clip(config.gradient_value_clip))

    # Define the learning rate schedule.
    schedule_fn = utils.get_optax_schedule_fn(
        warmup_ratio=config.get("warmup_ratio", 0.),
        num_train_steps=config.num_train_steps,
        decay=config.get("learning_rate_step_decay", 1.0),
        decay_at_steps=config.get("learning_rate_decay_at_steps", []),
        cosine_decay_schedule=config.get("cosine_decay", False))

    schedule_ops = [optax.scale_by_schedule(schedule_fn)]

    # Scale some parameters matching a regex by a multiplier. Config field
    # `scaling_by_regex` is a list of pairs (regex: str, multiplier: float).
    scaling_by_regex = config.get("scaling_learning_rate_by_regex", [])
    for regex, multiplier in scaling_by_regex:
        logging.info(
            "Learning rate is scaled by %f for parameters matching '%s'",
            multiplier, regex)
        schedule_ops.append(utils.scale_selected_parameters(regex, multiplier))
    schedule_optimizer = optax.chain(*schedule_ops)

    if config.optimizer.lower() == "adam":
        optimizer = optax.adam(config.learning_rate)
        ops.append(optimizer)
        ops.append(schedule_optimizer)
    elif config.optimizer.lower() == "sgd":
        ops.append(schedule_optimizer)
        optimizer = optax.sgd(config.learning_rate, momentum=config.momentum)
        ops.append(optimizer)
    else:
        raise NotImplementedError("Invalid optimizer: {}".format(
            config.optimizer))

    if "weight_decay" in config and config.weight_decay > 0.:
        ops.append(utils.decoupled_weight_decay(
            decay=config.weight_decay, step_size_fn=schedule_fn))

    # Freeze parameters that match the given regexes (if any).
    freeze_weights_regexes = config.get("freeze_weights_regex", []) or []
    if isinstance(freeze_weights_regexes, str):
        freeze_weights_regexes = [freeze_weights_regexes]
    for reg in freeze_weights_regexes:
        ops.append(utils.freeze(reg))

    return optax.chain(*ops)

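# `utils.freeze` above is a project-local transform. A hedged stand-in using
# only stock optax: zero out the updates of parameters whose flattened path
# matches a regex, via optax.masked over optax.set_to_zero. Assumes params are
# a nested dict (e.g. Flax-style); names here are illustrative only.
import re
import optax
from flax import traverse_util

def freeze_by_regex(regex):
    pattern = re.compile(regex)
    def mask_fn(params):
        flat = traverse_util.flatten_dict(params)
        mask = {path: bool(pattern.search('/'.join(path))) for path in flat}
        return traverse_util.unflatten_dict(mask)
    # set_to_zero is applied only to the matching (True) leaves.
    return optax.masked(optax.set_to_zero(), mask_fn)
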
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    forward_fn: networks.PolicyValueRNN,
    unroll_fn: networks.PolicyValueRNN,
    initial_state_fn: Callable[[], hk.LSTMState],
    sequence_length: int,
    sequence_period: int,
    counter: counting.Counter = None,
    logger: loggers.Logger = None,
    discount: float = 0.99,
    max_queue_size: int = 100000,
    batch_size: int = 16,
    learning_rate: float = 1e-3,
    entropy_cost: float = 0.01,
    baseline_cost: float = 0.5,
    seed: int = 0,
    max_abs_reward: float = np.inf,
    max_gradient_norm: float = np.inf,
):
    # Data is handled by the reverb replay queue.
    num_actions = environment_spec.actions.num_values
    self._logger = logger or loggers.TerminalLogger('agent')
    extra_spec = {
        'core_state': hk.without_apply_rng(
            hk.transform(initial_state_fn)).apply(None),
        'logits': np.ones(shape=(num_actions,), dtype=np.float32)
    }
    reverb_queue = replay.make_reverb_online_queue(
        environment_spec=environment_spec,
        extra_spec=extra_spec,
        max_queue_size=max_queue_size,
        sequence_length=sequence_length,
        sequence_period=sequence_period,
        batch_size=batch_size,
    )
    self._server = reverb_queue.server
    self._can_sample = reverb_queue.can_sample

    # Make the learner.
    optimizer = optax.chain(
        optax.clip_by_global_norm(max_gradient_norm),
        optax.adam(learning_rate),
    )
    key_learner, key_actor = jax.random.split(jax.random.PRNGKey(seed))
    self._learner = learning.IMPALALearner(
        obs_spec=environment_spec.observations,
        unroll_fn=unroll_fn,
        initial_state_fn=initial_state_fn,
        iterator=reverb_queue.data_iterator,
        random_key=key_learner,
        counter=counter,
        logger=logger,
        optimizer=optimizer,
        discount=discount,
        entropy_cost=entropy_cost,
        baseline_cost=baseline_cost,
        max_abs_reward=max_abs_reward,
    )

    # Make the actor.
    variable_client = variable_utils.VariableClient(self._learner,
                                                    key='policy')
    transformed = hk.without_apply_rng(hk.transform(forward_fn))
    self._actor = acting.IMPALAActor(
        forward_fn=jax.jit(transformed.apply, backend='cpu'),
        initial_state_fn=initial_state_fn,
        rng=hk.PRNGSequence(key_actor),
        adder=reverb_queue.adder,
        variable_client=variable_client,
    )

def main(argv):
    """Trains Rainbow agent on Atari."""
    del argv
    logging.info('Rainbow on Atari on %s.',
                 jax.lib.xla_bridge.get_backend().platform)
    random_state = np.random.RandomState(FLAGS.seed)
    rng_key = jax.random.PRNGKey(
        random_state.randint(-sys.maxsize - 1, sys.maxsize + 1,
                             dtype=np.int64))

    if FLAGS.results_csv_path:
        writer = parts.CsvWriter(FLAGS.results_csv_path)
    else:
        writer = parts.NullWriter()

    def environment_builder():
        """Creates Atari environment."""
        env = gym_atari.GymAtari(
            FLAGS.environment_name, seed=random_state.randint(1, 2**32))
        return gym_atari.RandomNoopsEnvironmentWrapper(
            env,
            min_noop_steps=1,
            max_noop_steps=30,
            seed=random_state.randint(1, 2**32),
        )

    env = environment_builder()

    logging.info('Environment: %s', FLAGS.environment_name)
    logging.info('Action spec: %s', env.action_spec())
    logging.info('Observation spec: %s', env.observation_spec())
    num_actions = env.action_spec().num_values
    support = jnp.linspace(-FLAGS.vmax, FLAGS.vmax, FLAGS.num_atoms)
    network_fn = networks.rainbow_atari_network(num_actions, support,
                                                FLAGS.noisy_weight_init)
    network = hk.transform(network_fn)

    def preprocessor_builder():
        return processors.atari(
            additional_discount=FLAGS.additional_discount,
            max_abs_reward=FLAGS.max_abs_reward,
            resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
            num_action_repeats=FLAGS.num_action_repeats,
            num_pooled_frames=2,
            zero_discount_on_life_loss=True,
            num_stacked_frames=FLAGS.num_stacked_frames,
            grayscaling=True,
        )

    # Create sample network input from sample preprocessor output.
    sample_processed_timestep = preprocessor_builder()(env.reset())
    sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                            sample_processed_timestep)
    sample_network_input = sample_processed_timestep.observation
    chex.assert_shape(sample_network_input,
                      (FLAGS.environment_height, FLAGS.environment_width,
                       FLAGS.num_stacked_frames))

    # Note the t in the replay is not exactly aligned with the agent t.
    importance_sampling_exponent_schedule = parts.LinearSchedule(
        begin_t=int(FLAGS.min_replay_capacity_fraction *
                    FLAGS.replay_capacity),
        end_t=(FLAGS.num_iterations *
               int(FLAGS.num_train_frames / FLAGS.num_action_repeats)),
        begin_value=FLAGS.importance_sampling_exponent_begin_value,
        end_value=FLAGS.importance_sampling_exponent_end_value)

    if FLAGS.compress_state:

        def encoder(transition):
            return transition._replace(
                s_tm1=replay_lib.compress_array(transition.s_tm1),
                s_t=replay_lib.compress_array(transition.s_t))

        def decoder(transition):
            return transition._replace(
                s_tm1=replay_lib.uncompress_array(transition.s_tm1),
                s_t=replay_lib.uncompress_array(transition.s_t))
    else:
        encoder = None
        decoder = None

    replay_structure = replay_lib.Transition(
        s_tm1=None,
        a_tm1=None,
        r_t=None,
        discount_t=None,
        s_t=None,
    )

    transition_accumulator = replay_lib.NStepTransitionAccumulator(
        FLAGS.n_steps)
    replay = replay_lib.PrioritizedTransitionReplay(
        FLAGS.replay_capacity, replay_structure, FLAGS.priority_exponent,
        importance_sampling_exponent_schedule,
        FLAGS.uniform_sample_probability, FLAGS.normalize_weights,
        random_state, encoder, decoder)

    optimizer = optax.adam(
        learning_rate=FLAGS.learning_rate, eps=FLAGS.optimizer_epsilon)
    if FLAGS.max_global_grad_norm > 0:
        optimizer = optax.chain(
            optax.clip_by_global_norm(FLAGS.max_global_grad_norm), optimizer)

    train_rng_key, eval_rng_key = jax.random.split(rng_key)

    train_agent = agent.Rainbow(
        preprocessor=preprocessor_builder(),
        sample_network_input=sample_network_input,
        network=network,
        support=support,
        optimizer=optimizer,
        transition_accumulator=transition_accumulator,
        replay=replay,
        batch_size=FLAGS.batch_size,
        min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
        learn_period=FLAGS.learn_period,
        target_network_update_period=FLAGS.target_network_update_period,
        rng_key=train_rng_key,
    )
    eval_agent = parts.EpsilonGreedyActor(
        preprocessor=preprocessor_builder(),
        network=network,
        exploration_epsilon=0,
        rng_key=eval_rng_key,
    )

    # Set up checkpointing.
    checkpoint = parts.NullCheckpoint()
    state = checkpoint.state
    state.iteration = 0
    state.train_agent = train_agent
    state.eval_agent = eval_agent
    state.random_state = random_state
    state.writer = writer
    if checkpoint.can_be_restored():
        checkpoint.restore()

    while state.iteration <= FLAGS.num_iterations:
        # New environment for each iteration to allow for determinism if
        # preempted.
        env = environment_builder()

        logging.info('Training iteration %d.', state.iteration)
        train_seq = parts.run_loop(train_agent, env,
                                   FLAGS.max_frames_per_episode)
        num_train_frames = (0 if state.iteration == 0
                            else FLAGS.num_train_frames)
        train_seq_truncated = itertools.islice(train_seq, num_train_frames)
        train_trackers = parts.make_default_trackers(train_agent)
        train_stats = parts.generate_statistics(train_trackers,
                                                train_seq_truncated)

        logging.info('Evaluation iteration %d.', state.iteration)
        eval_agent.network_params = train_agent.online_params
        eval_seq = parts.run_loop(eval_agent, env,
                                  FLAGS.max_frames_per_episode)
        eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
        eval_trackers = parts.make_default_trackers(eval_agent)
        eval_stats = parts.generate_statistics(eval_trackers,
                                               eval_seq_truncated)

        # Logging and checkpointing.
        human_normalized_score = atari_data.get_human_normalized_score(
            FLAGS.environment_name, eval_stats['episode_return'])
        capped_human_normalized_score = np.amin([1., human_normalized_score])
        log_output = [
            ('iteration', state.iteration, '%3d'),
            ('frame', state.iteration * FLAGS.num_train_frames, '%5d'),
            ('eval_episode_return', eval_stats['episode_return'], '% 2.2f'),
            ('train_episode_return', train_stats['episode_return'], '% 2.2f'),
            ('eval_num_episodes', eval_stats['num_episodes'], '%3d'),
            ('train_num_episodes', train_stats['num_episodes'], '%3d'),
            ('eval_frame_rate', eval_stats['step_rate'], '%4.0f'),
            ('train_frame_rate', train_stats['step_rate'], '%4.0f'),
            ('train_state_value', train_stats['state_value'], '%.3f'),
            ('importance_sampling_exponent',
             train_agent.importance_sampling_exponent, '%.3f'),
            ('max_seen_priority', train_agent.max_seen_priority, '%.3f'),
            ('normalized_return', human_normalized_score, '%.3f'),
            ('capped_normalized_return', capped_human_normalized_score,
             '%.3f'),
            ('human_gap', 1. - capped_human_normalized_score, '%.3f'),
        ]
        log_output_str = ', '.join(('%s: ' + f) % (n, v)
                                   for n, v, f in log_output)
        logging.info(log_output_str)
        writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
        state.iteration += 1
        checkpoint.save()

    writer.close()

def train_model(train_file_pattern,
                test_file_pattern,
                max_files_to_load=None,
                n_epochs=1000,
                time_index=9,
                learning_rate=1e-4,
                grad_clip=1.0,
                measurement_store_interval=1000,
                checkpoint_path=None):
    """Trains GraphModel using JAX.

    Args:
      train_file_pattern: pattern matching the files with the training data.
      test_file_pattern: pattern matching the files with the test data.
      max_files_to_load: the maximum number of train and test files to load.
        If None, all files will be loaded.
      n_epochs: the number of passes through the training dataset (epochs).
      time_index: the time index (0-9) of the target mobilities.
      learning_rate: the learning rate used by the optimizer.
      grad_clip: all gradients are clipped to the given value.
      measurement_store_interval: number of steps between storing objective
        values (loss and correlation).
      checkpoint_path: ignored by this implementation.
    """
    if checkpoint_path:
        logging.warning('The checkpoint_path argument is ignored.')
    random.seed(42)
    np.random.seed(42)

    # Loads train and test dataset.
    dataset_kwargs = dict(
        time_index=time_index, max_files_to_load=max_files_to_load)
    logging.info('Load training data')
    training_data = load_data(train_file_pattern, **dataset_kwargs)
    logging.info('Load test data')
    test_data = load_data(test_file_pattern, **dataset_kwargs)
    logging.info('Finished loading data')

    network = hk.without_apply_rng(hk.transform(network_definition))
    params = network.init(jax.random.PRNGKey(42), training_data[0][0])

    opt_init, opt_update = optax.chain(
        optax.clip_by_global_norm(grad_clip),
        optax.scale_by_adam(0.9, 0.999, 1e-8),
        optax.scale(-learning_rate))
    opt_state = opt_init(params)

    network_apply = jax.jit(network.apply)

    @jax.jit
    def loss_fn(params, graph, targets, mask):
        decoded_nodes = network_apply(params, graph) * mask
        return jnp.sum((decoded_nodes - targets)**2 * mask) / jnp.sum(mask)

    @jax.jit
    def update(params, opt_state, graph, targets, mask):
        loss, grads = jax.value_and_grad(loss_fn)(params, graph, targets,
                                                  mask)
        updates, opt_state = opt_update(grads, opt_state)
        return optax.apply_updates(params, updates), opt_state, loss

    train_stats = []
    i = 0
    logging.info('Start training')
    for epoch in range(n_epochs):
        logging.info('Start epoch %r', epoch)
        random.shuffle(training_data)
        for graph, targets, mask in training_data:
            graph = apply_random_rotation(graph)
            params, opt_state, loss = update(params, opt_state, graph,
                                             targets, mask)
            train_stats.append(loss)

            if (i + 1) % measurement_store_interval == 0:
                logging.info('Start evaluation run')
                test_stats = []
                for test_graph, test_targets, test_mask in test_data:
                    predictions = network_apply(params, test_graph)
                    test_stats.append(np.corrcoef(
                        predictions[test_mask == 1],
                        test_targets[test_mask == 1])[0, 1])
                logging.info('Train loss %r', np.mean(train_stats))
                logging.info('Test correlation %r', np.mean(test_stats))
                train_stats = []
            i += 1

    (mel1_hat, mel2_hat), new_aux = (net if is_training else val_net).apply(
        params, aux, rng, inputs)
    loss1 = (jnp.square(mel1_hat - mels) + jnp.square(mel2_hat - mels)) / 2
    loss2 = (jnp.abs(mel1_hat - mels) + jnp.abs(mel2_hat - mels)) / 2
    loss = jnp.mean((loss1 + loss2) / 2, axis=-1)
    mask = (jnp.arange(0, L)[None, :] - 10) < (
        inputs.wav_lengths // (FLAGS.n_fft // 4))[:, None]
    loss = jnp.sum(loss * mask) / jnp.sum(mask)
    return (loss, new_aux) if is_training else (loss, new_aux, mel2_hat, mels)

train_loss_fn = partial(loss_fn, is_training=True)
val_loss_fn = jax.jit(partial(loss_fn, is_training=False))
loss_vag = jax.value_and_grad(train_loss_fn, has_aux=True)

optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(FLAGS.learning_rate)
)

@jax.jit
def update(params, aux, rng, optim_state, inputs):
    rng, new_rng = jax.random.split(rng)
    (loss, new_aux), grads = loss_vag(params, aux, rng, inputs)
    updates, new_optim_state = optimizer.update(grads, optim_state, params)
    new_params = optax.apply_updates(params, updates)
    return loss, (new_params, new_aux, new_rng, new_optim_state)

def initial_state(batch):
    rng = jax.random.PRNGKey(42)

forward_fn = jax.jit(
    hk.transform_with_state(
        lambda x: DurationModel(is_training=False)(x)).apply)

def predict_duration(params, aux, rng, x: DurationInput):
    d, _ = forward_fn(params, aux, rng, x)
    return d, x.durations

val_loss_fn = jax.jit(partial(loss_fn, is_training=False))
loss_vag = jax.value_and_grad(loss_fn, has_aux=True)

optimizer = optax.chain(
    optax.clip_by_global_norm(FLAGS.max_grad_norm),
    optax.adam(FLAGS.learning_rate))

@jax.jit
def update(params, aux, rng, optim_state, inputs: DurationInput):
    rng, new_rng = jax.random.split(rng)
    (loss, new_aux), grads = loss_vag(params, aux, rng, inputs)
    updates, new_optim_state = optimizer.update(grads, optim_state, params)
    new_params = optax.apply_updates(params, updates)
    return loss, (new_params, new_aux, new_rng, new_optim_state)

def initial_state(batch):
    rng = jax.random.PRNGKey(42)
    params, aux = hk.transform_with_state(lambda x: DurationModel(True)

def main(argv):
    """Train pick-up and drop-off Rainbow agents on ODySSEUS."""
    del argv  # Unused arguments

    # Metadata configuration
    parent_dir = pathlib.Path(__file__).parent.absolute()
    sim_input_conf_dir = parent_dir / 'configs' / DEFAULT_sim_scenario_name

    # Load configuration
    sim_conf = importlib.import_module('esbdqn.configs.{}.{}'.format(
        DEFAULT_sim_scenario_name, FLAGS.conf_filename))

    # Extract a single conf pair
    sim_general_conf = EFFCS_SimConfGrid(sim_conf.General).conf_list[0]
    sim_scenario_conf = EFFCS_SimConfGrid(sim_conf.Multiple_runs).conf_list[0]

    experiment_dir = (parent_dir / 'experiments' / DEFAULT_sim_scenario_name
                      / FLAGS.exp_name / sim_general_conf['city'])

    if pathlib.Path.exists(experiment_dir):
        # Ensure configuration has not changed
        if not filecmp.cmp(
                str(sim_input_conf_dir / FLAGS.conf_filename) + '.py',
                str(experiment_dir / DEFAULT_conf_filename) + '.py',
                shallow=False):
            raise IOError('Configuration changed at: {}'.format(
                str(experiment_dir)))
    else:
        pathlib.Path.mkdir(experiment_dir, parents=True, exist_ok=True)

        # Copy configuration files
        shutil.rmtree(experiment_dir)
        shutil.copytree(sim_input_conf_dir, experiment_dir)

        # Rename to the default name
        conf_filepath = experiment_dir / (FLAGS.conf_filename + '.py')
        conf_filepath.rename(experiment_dir / (DEFAULT_conf_filename + '.py'))

        # Delete all other potential conf files
        for filename in experiment_dir.glob(DEFAULT_conf_filename + '_*.py'):
            filename.unlink()

    # Create results files
    results_dir = experiment_dir / 'results'
    pathlib.Path.mkdir(results_dir, parents=True, exist_ok=True)
    results_filepath = results_dir / DEFAULT_resu_filename

    logging.info('Rainbow agents on ODySSEUS running on %s.',
                 jax.lib.xla_bridge.get_backend().platform.upper())

    if FLAGS.checkpoint:
        checkpoint = PickleCheckpoint(
            experiment_dir / 'models',
            'ODySSEUS-' + sim_general_conf['city'])
    else:
        checkpoint = parts.NullCheckpoint()

    checkpoint_restored = False
    if FLAGS.checkpoint:
        if checkpoint.can_be_restored():
            logging.info('Restoring checkpoint...')
            checkpoint.restore()
            checkpoint_restored = True

    # Generate RNG key
    rng_state = np.random.RandomState(FLAGS.seed)
    if checkpoint_restored:
        rng_state.set_state(checkpoint.state.rng_state)

    rng_key = jax.random.PRNGKey(
        rng_state.randint(-sys.maxsize - 1, sys.maxsize + 1, dtype=np.int64))

    # Generate results file writer
    if sim_general_conf['save_history']:
        writer = parts.CsvWriter(str(results_filepath))
        if checkpoint_restored:
            writer.set_state(checkpoint.state.writer)
    else:
        writer = parts.NullWriter()

    def environment_builder() -> ConstrainedEnvironment:
        """Create the ODySSEUS environment."""
        return EscooterSimulator(
            (sim_general_conf, sim_scenario_conf), FLAGS.n_lives)

    def preprocessor_builder():
        """Create the ODySSEUS input preprocessor."""
        return processor(
            max_abs_reward=FLAGS.max_abs_reward,
            zero_discount_on_life_loss=True
        )

    env = environment_builder()

    logging.info('Environment: %s', FLAGS.exp_name)
    logging.info('Action spec: %s', env.action_spec())
    logging.info('Observation spec: %s', env.observation_spec())

    # Take [0] as both Rainbow agents have the same number of actions
    num_actions = env.action_spec()[0].num_values
    support = jnp.linspace(-FLAGS.vmax, FLAGS.vmax, FLAGS.num_atoms)
    network = hk.transform(rainbow_odysseus_network(
        num_actions, support, FLAGS.noisy_weight_init))

    # Create sample network input from reset.
    sample_processed_timestep = preprocessor_builder()(env.reset())
    sample_processed_timestep = t.cast(dm_env.TimeStep,
                                       sample_processed_timestep)
    sample_processed_network_input = sample_processed_timestep.observation

    # Note the t in the replay is not exactly
    # aligned with the Rainbow agents t.
    importance_sampling_exponent_schedule = parts.LinearSchedule(
        begin_t=int(FLAGS.min_replay_capacity_fraction *
                    FLAGS.replay_capacity),
        end_t=(FLAGS.num_iterations * FLAGS.num_train_frames),
        begin_value=FLAGS.importance_sampling_exponent_begin_value,
        end_value=FLAGS.importance_sampling_exponent_end_value)

    if FLAGS.compress_state:

        def encoder(transition):
            return transition._replace(
                s_tm1=replay.compress_array(transition.s_tm1),
                s_t=replay.compress_array(transition.s_t))

        def decoder(transition):
            return transition._replace(
                s_tm1=replay.uncompress_array(transition.s_tm1),
                s_t=replay.uncompress_array(transition.s_t))
    else:
        encoder = None
        decoder = None

    replay_struct = replay.Transition(
        s_tm1=None,
        a_tm1=None,
        r_t=None,
        discount_t=None,
        s_t=None,
    )

    transition_accumulator = replay.NStepTransitionAccumulator(FLAGS.n_steps)
    transition_replay = replay.PrioritizedTransitionReplay(
        FLAGS.replay_capacity, replay_struct, FLAGS.priority_exponent,
        importance_sampling_exponent_schedule,
        FLAGS.uniform_sample_probability, FLAGS.normalize_weights,
        rng_state, encoder, decoder)

    optimizer = optax.adam(
        learning_rate=FLAGS.learning_rate, eps=FLAGS.optimizer_epsilon)
    if FLAGS.max_global_grad_norm > 0:
        optimizer = optax.chain(
            optax.clip_by_global_norm(FLAGS.max_global_grad_norm), optimizer)

    train_rng_key, eval_rng_key = jax.random.split(rng_key)

    # Create pick-up/drop-off agents
    P_train_agent = agent.Rainbow(
        preprocessor=preprocessor_builder(),
        sample_network_input=copy.deepcopy(sample_processed_network_input),
        network=copy.deepcopy(network),
        support=copy.deepcopy(support),
        optimizer=copy.deepcopy(optimizer),
        transition_accumulator=copy.deepcopy(transition_accumulator),
        replay=copy.deepcopy(transition_replay),
        batch_size=FLAGS.batch_size,
        min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
        learn_period=FLAGS.learn_period,
        target_network_update_period=FLAGS.target_network_update_period,
        rng_key=train_rng_key,
    )
    D_train_agent = agent.Rainbow(
        preprocessor=preprocessor_builder(),
        sample_network_input=copy.deepcopy(sample_processed_network_input),
        network=copy.deepcopy(network),
        support=copy.deepcopy(support),
        optimizer=copy.deepcopy(optimizer),
        transition_accumulator=copy.deepcopy(transition_accumulator),
        replay=copy.deepcopy(transition_replay),
        batch_size=FLAGS.batch_size,
        min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
        learn_period=FLAGS.learn_period,
        target_network_update_period=FLAGS.target_network_update_period,
        rng_key=train_rng_key,
    )

    P_eval_agent = parts.EpsilonGreedyActor(
        preprocessor=preprocessor_builder(),
        network=copy.deepcopy(network),
        exploration_epsilon=0,
        rng_key=eval_rng_key,
    )
    D_eval_agent = parts.EpsilonGreedyActor(
        preprocessor=preprocessor_builder(),
        network=copy.deepcopy(network),
        exploration_epsilon=0,
        rng_key=eval_rng_key,
    )

    if checkpoint_restored:
        P_train_agent.set_state(checkpoint.state.P_agent['train'])
        D_train_agent.set_state(checkpoint.state.D_agent['train'])
        P_eval_agent.set_state(checkpoint.state.P_agent['eval'])
        D_eval_agent.set_state(checkpoint.state.D_agent['eval'])

    state = checkpoint.state
    if not checkpoint_restored:
        state.iteration = 0
        state.P_agent = {}
        state.D_agent = {}

    state.rng_state = rng_state
    state.writer = writer

    state.P_agent['train'] = P_train_agent
    state.D_agent['train'] = D_train_agent
    state.P_agent['eval'] = P_eval_agent
    state.D_agent['eval'] = D_eval_agent

    while state.iteration < FLAGS.num_iterations:
        # Create a new environment at each new iteration
        # to allow for determinism if preempted.
        env = environment_builder()

        # Leave some spacing
        print('\n')

        logging.info('Training iteration: %d', state.iteration)

        train_trackers = make_odysseus_trackers(FLAGS.max_abs_reward)
        eval_trackers = make_odysseus_trackers(FLAGS.max_abs_reward)

        train_seq = run_loop(P_train_agent, D_train_agent, env,
                             FLAGS.max_steps_per_episode)
        num_train_frames = (0 if state.iteration == 0
                            else FLAGS.num_train_frames)
        train_seq_truncated = it.islice(train_seq, num_train_frames)
        train_stats = generate_statistics(train_trackers, train_seq_truncated)

        logging.info('Evaluation iteration: %d', state.iteration)

        # Synchronize network parameters
        P_eval_agent.network_params = P_train_agent.online_params
        D_eval_agent.network_params = D_train_agent.online_params

        eval_seq = run_loop(P_eval_agent, D_eval_agent, env,
                            FLAGS.max_steps_per_episode)
        eval_seq_truncated = it.islice(eval_seq, FLAGS.num_eval_frames)
        eval_stats = generate_statistics(eval_trackers, eval_seq_truncated)

        # Logging and checkpointing
        L = [
            # Simulation metadata
            ('iteration', state.iteration, '%3d'),
            # ODySSEUS metadata
            ('n_charging_workers', sim_scenario_conf['n_workers'], '%3d'),
            ('n_relocation_workers',
             sim_scenario_conf['n_relocation_workers'], '%3d'),
            ('n_vehicles', sim_scenario_conf['n_vehicles'], '%3d'),
            ('pct_incentive_willingness',
             sim_scenario_conf['incentive_willingness'], '%2.2f'),
            ('zone_side_m', sim_general_conf['bin_side_length'], '%3d'),
            # Validation agents
            ('eval_num_episodes', eval_stats['num_episodes'], '%3d'),
            ('eval_P_episode_return', eval_stats['episode_return'][0],
             '%2.2f'),
            ('eval_D_episode_return', eval_stats['episode_return'][1],
             '%2.2f'),
            ('eval_min_n_accepted_incentives',
             np.min(eval_stats['episodes_n_accepted_incentives']), '%2.2f'),
            ('eval_avg_n_accepted_incentives',
             np.mean(eval_stats['episodes_n_accepted_incentives']), '%2.2f'),
            ('eval_max_n_accepted_incentives',
             np.max(eval_stats['episodes_n_accepted_incentives']), '%2.2f'),
            ('eval_min_n_lives',
             np.min(eval_stats['episodes_n_lives']), '%2.2f'),
            ('eval_avg_n_lives',
             np.mean(eval_stats['episodes_n_lives']), '%2.2f'),
            ('eval_max_n_lives',
             np.max(eval_stats['episodes_n_lives']), '%2.2f'),
            ('eval_min_pct_satisfied_demand',
             np.min(eval_stats['pct_satisfied_demands']), '%2.2f'),
            ('eval_avg_pct_satisfied_demand',
             np.mean(eval_stats['pct_satisfied_demands']), '%2.2f'),
            ('eval_max_pct_satisfied_demand',
             np.max(eval_stats['pct_satisfied_demands']), '%2.2f'),
            # Training agents
            ('train_num_episodes', train_stats['num_episodes'], '%3d'),
            ('train_P_episode_return', train_stats['episode_return'][0],
             '%2.2f'),
            ('train_D_episode_return', train_stats['episode_return'][1],
             '%2.2f'),
            ('train_min_n_accepted_incentives',
             np.min(train_stats['episodes_n_accepted_incentives']), '%2.2f'),
            ('train_avg_n_accepted_incentives',
             np.mean(train_stats['episodes_n_accepted_incentives']), '%2.2f'),
            ('train_max_n_accepted_incentives',
             np.max(train_stats['episodes_n_accepted_incentives']), '%2.2f'),
            ('train_min_n_lives',
             np.min(train_stats['episodes_n_lives']), '%2.2f'),
            ('train_avg_n_lives',
             np.mean(train_stats['episodes_n_lives']), '%2.2f'),
            ('train_max_n_lives',
             np.max(train_stats['episodes_n_lives']), '%2.2f'),
            ('train_min_pct_satisfied_demand',
             np.min(train_stats['pct_satisfied_demands']), '%2.2f'),
            ('train_avg_pct_satisfied_demand',
             np.mean(train_stats['pct_satisfied_demands']), '%2.2f'),
            ('train_max_pct_satisfied_demand',
             np.max(train_stats['pct_satisfied_demands']), '%2.2f'),
            ('P_importance_sampling_exponent',
             P_train_agent.importance_sampling_exponent, '%.3f'),
            ('D_importance_sampling_exponent',
             D_train_agent.importance_sampling_exponent, '%.3f'),
            ('P_max_seen_priority', P_train_agent.max_seen_priority, '%.3f'),
            ('D_max_seen_priority', D_train_agent.max_seen_priority, '%.3f'),
        ]

        L_str = '\n'.join(('%s: ' + f) % (n, v) for n, v, f in L)
        logging.info(L_str)

        if state.iteration == FLAGS.num_iterations - 1:
            print('\n')

        writer.write(collections.OrderedDict((n, v) for n, v, _ in L))
        state.iteration += 1

        if state.iteration % FLAGS.checkpoint_period == 0:
            checkpoint.save()

    writer.close()

import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
os.environ["JAX_DEBUG_NANS"] = "True"

from swarm_jax.swarm_layer import NetworkPrecision
from loader import TextLoader
from swarm_jax.model import SwarmCharTransformer
from swarm_jax.swarm import Swarm

import ray
import optax

ray.init(resources={"tpu": 999})  # pretend we have infinite tpus lol

train_dataset = TextLoader("data/enwik8", batchsize=(1, 16), sample_size=128,
                           length=90000000)

optimizer = optax.chain(optax.clip_by_global_norm(0.25),
                        optax.adam(2e-4, b1=0.9, b2=0.99, eps=1e-5))

prec = NetworkPrecision(fwd_act="uint16", rev_act="uint16", grad="uint16")

model = SwarmCharTransformer
swarm = Swarm(model, optimizer, 2**16, train_dataset.get_samples, prec)
swarm.run(100000, "runs/512_30L", "ckpt/512_30L")

ray.shutdown()

def train(config: ml_collections.ConfigDict):
    """Run training."""
    # Establish host information
    local_device_count = jax.local_device_count()
    host_count = jax.process_count()
    host_id = jax.process_index()

    task = task_registry.get_registered_task(config.task_name)

    start_step = 0
    rng = jax.random.PRNGKey(config.seed)

    model_config = ml_collections.FrozenConfigDict(config.model_config)
    model = task.build_model(model_config)

    # Initialization needs to be pmapped because models use collective ops.
    # Create dummy input
    dummy_input = {
        key: jnp.tile(value, (local_device_count,) + (1,) * value.ndim)
        for key, value in task.dummy_input(config).items()
    }

    rng, init_rng = jax.random.split(rng)
    init_rng = jax.random.split(init_rng, local_device_count)

    logging.info('Initializing model.')
    initial_variables = jax.pmap(
        model.init, 'batch', static_broadcasted_argnums=2)(init_rng,
                                                           dummy_input, True)
    logging.info('Finished initializing model.')
    initial_variables = initial_variables.unfreeze()

    if config.load_weights is not None:
        logging.info('Loading model weights from file')
        loaded_variables = task.load_weights(config)
        unexpected, missing = checkpoint_utils.merge_nested_dicts(
            initial_variables, loaded_variables)
        logging.info('*** Unexpected features: ***')
        for feature_name in unexpected:
            logging.info('\t%s', feature_name)
        logging.info('*** Missing features: ***')
        for feature_name in missing:
            logging.info('\t%s', feature_name)

    model_vars = {
        key: value
        for key, value in initial_variables.items()
        if key != 'params'
    }

    learning_rate_fn = optim_utils.create_learning_rate_scheduler(
        learning_rate=config.learning_rate,
        warmup=config.warmup,
        warmup_steps=config.get('warmup_steps', None),
        linear_decay=config.linear_decay,
        max_steps=config.num_train_steps,
        decay_minimum_factor=config.get('decay_minimum_factor', None),
    )

    if config.weight_decay_exclude is not None:
        decay_mask = optim_utils.create_dict_mask(
            initial_variables['params'], config.weight_decay_exclude)
    else:
        decay_mask = None
    tx = optax.adamw(
        learning_rate=learning_rate_fn,
        weight_decay=config.weight_decay,
        b1=0.9,
        b2=0.999,
        eps=1e-6,
        mask=decay_mask)
    if config.grad_clip is not None:
        tx = optax.chain(tx, optax.clip_by_global_norm(config.grad_clip))
    ignore_k_nans = config.get('ignore_k_nans')
    if ignore_k_nans is not None:
        tx = optax.apply_if_finite(tx, max_consecutive_errors=ignore_k_nans)

    loss_fn = task.make_loss_fn(config)
    train_state = ts.TrainState.create(
        apply_fn=loss_fn,
        params=jax_utils.unreplicate(initial_variables['params']),
        tx=tx,
    )
    # We access model params only from train state.
    del initial_variables

    # Restore unreplicated train state from last checkpoint
    train_state = checkpoints.restore_checkpoint(config.model_dir,
                                                 train_state)
    # Grab last step.
    start_step = int(train_state.step)

    writer = metric_writers.create_default_writer(
        config.model_dir, just_logging=jax.process_index() > 0)
    if start_step == 0:
        writer.write_hparams(config.to_dict())

    dropout_rngs = jax.random.split(rng, local_device_count)
    del rng

    # Load datasets
    logging.info('Loading dataset.')

    # Make sure we don't re-use same data if we load weights or checkpoint
    seed = config.seed + start_step
    if config.load_weights:
        seed = seed + hash(config.load_weights)

    name_to_features = task.get_name_to_features(config)
    preprocess_fn = task.make_preprocess_fn(config)
    collater_fn = task.make_collater_fn(config)

    train_data = data_utils.load_multi_dataset(
        datasets_config=config.train_data,
        name_to_features=name_to_features,
        preprocess_fn=preprocess_fn,
        collater_fn=collater_fn,
        is_training=True,
        per_device_batch_size=config.per_device_batch_size,
        local_device_count=local_device_count,
        host_count=host_count,
        host_id=host_id,
        seed=config.seed,
    )
    train_iter = iter(train_data)

    pad_eval = config.get('pad_eval', False)
    if pad_eval:
        logging.info('Eval data is padded such that none of samples are '
                     'dropped.')
    else:
        logging.warn('Eval data is NOT padded -- some samples might be '
                     'dropped.')
    eval_data = data_utils.load_multi_dataset(
        datasets_config=config.eval_data,
        name_to_features=name_to_features,
        preprocess_fn=preprocess_fn,
        collater_fn=collater_fn,
        is_training=False,
        per_device_batch_size=config.per_device_batch_size,
        local_device_count=local_device_count,
        host_count=host_count,
        host_id=host_id,
        seed=config.seed,
        pad_eval=pad_eval,
    )
    eval_data = list(eval_data)
    logging.info('Loaded %d samples for evaluation.', len(eval_data))

    # Setup postprocessing_fn for saving samples occasionally.
    if config.get('save_samples_every_steps') is not None:
        if config.get('save_samples_every_steps') % config.eval_every_steps != 0:
            raise ValueError(
                '`eval_every_steps` must divide `save_samples_every_steps`.')
        postprocessing_fn = task.make_output_postprocess_fn(config)

    # Training loop
    logging.info('Starting training.')

    # Replicate train state.
    train_state = jax_utils.replicate(train_state)

    # compile multidevice versions of train/eval/predict step
    p_train_step = jax.pmap(
        functools.partial(
            train_step,
            model_config=model_config,
        ),
        axis_name='batch',
        donate_argnums=(0,),
    )  # pytype: disable=wrong-arg-types
    p_eval_step = jax.pmap(
        functools.partial(
            eval_step,
            model_config=model_config,
        ),
        axis_name='batch')

    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(logdir=config.model_dir,
                                     num_profile_steps=5)
        ]
    train_metrics = []
    with metric_writers.ensure_flushes(writer):
        for step in range(start_step, config.num_train_steps):
            is_last_step = step == config.num_train_steps - 1

            # Shard data to devices and perform a training step.
            with jax.profiler.StepTraceAnnotation('train', step_num=step):
                batch = jax.tree_map(jnp.asarray, train_iter.get_next())
                train_state, metrics = p_train_step(
                    train_state,
                    model_vars,
                    batch,
                    dropout_rngs,
                )
                train_metrics.append(metrics)

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, 'Finished training step %d', 5,
                                step)
            for h in hooks:
                h(step)

            # Periodic metric handling.
      if step % config.eval_every_steps == 0 or is_last_step:
        with report_progress.timed('training_metrics'):
          logging.info('Gathering training metrics.')
          train_metrics = common_utils.get_metrics(train_metrics)
          metrics_sums = jax.tree_map(jnp.sum, train_metrics)
          summary = metric_utils.process_metrics(metrics_sums, prefix='train')
          summary['learning_rate'] = learning_rate_fn(step)
          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed('eval'):
          eval_results, eval_auxiliary = evaluate(
              eval_step_fn=p_eval_step,
              train_state=train_state,
              model_vars=model_vars,
              eval_data=eval_data,
          )
          writer.write_scalars(step, eval_results)

        if config.get('save_samples_every_steps') is not None:
          with report_progress.timed('save_samples'):
            postprocessing_input = eval_auxiliary
            if config.get('save_first_batch_only', True):
              postprocessing_input = [eval_auxiliary[0]]
            eval_processed = [
                postprocessing_fn(batch, auxiliary_output)
                for batch, auxiliary_output in postprocessing_input
            ]
            data_utils.save_samples_to_json(eval_processed, config, step)

      # Save a checkpoint on one host after every checkpoint_every_steps steps.
      save_checkpoint = (
          step % config.checkpoint_every_steps == 0 or is_last_step)
      if (config.save_checkpoints and save_checkpoint and
          jax.process_index() == 0):
        with report_progress.timed('checkpoint'):
          logging.info('Saving checkpoint at step %s', step)
          checkpoints.save_checkpoint(
              config.model_dir,
              jax_utils.unreplicate(train_state),
              step,
              keep=config.get('keep_checkpoints', 1),
              keep_every_n_steps=config.get('keep_checkpoint_every_steps'),
          )

      save_model = (
          config.save_every_steps and
          (step % config.save_every_steps == 0 or is_last_step) and
          step != 0)
      if save_model and jax.process_index() == 0:
        with report_progress.timed('checkpoint'):
          logging.info('Saving weights at step %s', step)
          save_path = os.path.join(config.model_dir, 'weights',
                                   'step' + str(step))
          # By default, save only encoder weights.
          weights = jax_utils.unreplicate(train_state).params['encoder']
          checkpoint_utils.save_weights(save_path, weights)
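# --- Hedged sketch (not part of the training script above) -------------------
# A minimal, self-contained illustration of the optimizer chain built in
# train(): adamw with a weight-decay mask, its updates clipped by global norm,
# wrapped in optax.apply_if_finite so a few non-finite steps are skipped rather
# than corrupting the optimizer state. The toy params and hand-written mask are
# illustrative stand-ins; the project's optim_utils.create_dict_mask presumably
# builds a similar pytree of booleans from weight_decay_exclude patterns.
import jax
import jax.numpy as jnp
import optax

toy_params = {
    'encoder': {'kernel': jnp.ones((3, 3)), 'bias': jnp.zeros((3,))},
}
# Pytree-of-bools mask in the form optax.adamw expects: decay kernels only.
decay_mask = {'encoder': {'kernel': True, 'bias': False}}

toy_tx = optax.adamw(
    learning_rate=1e-3, b1=0.9, b2=0.999, eps=1e-6,
    weight_decay=0.01, mask=decay_mask)
# Same ordering as above: clipping acts on adamw's transformed updates.
toy_tx = optax.chain(toy_tx, optax.clip_by_global_norm(1.0))
toy_tx = optax.apply_if_finite(toy_tx, max_consecutive_errors=3)

opt_state = toy_tx.init(toy_params)
grads = jax.tree_util.tree_map(jnp.ones_like, toy_params)
updates, opt_state = toy_tx.update(grads, opt_state, toy_params)
toy_params = optax.apply_updates(toy_params, updates)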
  def __init__(
      self,
      environment_spec: specs.EnvironmentSpec,
      forward_fn: networks.PolicyValueRNN,
      unroll_fn: networks.PolicyValueRNN,
      initial_state_fn: Callable[[], hk.LSTMState],
      sequence_length: int,
      sequence_period: int,
      counter: counting.Counter = None,
      logger: loggers.Logger = None,
      discount: float = 0.99,
      max_queue_size: int = 100000,
      batch_size: int = 16,
      learning_rate: float = 1e-3,
      entropy_cost: float = 0.01,
      baseline_cost: float = 0.5,
      seed: int = 0,
      max_abs_reward: float = np.inf,
      max_gradient_norm: float = np.inf,
  ):
    num_actions = environment_spec.actions.num_values
    self._logger = logger or loggers.TerminalLogger('agent')

    extra_spec = {
        'core_state':
            hk.without_apply_rng(
                hk.transform(initial_state_fn, apply_rng=True)).apply(None),
        'logits':
            np.ones(shape=(num_actions,), dtype=np.float32)
    }
    signature = adders.SequenceAdder.signature(environment_spec, extra_spec)
    queue = reverb.Table.queue(
        name=adders.DEFAULT_PRIORITY_TABLE,
        max_size=max_queue_size,
        signature=signature)
    self._server = reverb.Server([queue], port=None)
    self._can_sample = lambda: queue.can_sample(batch_size)
    address = f'localhost:{self._server.port}'

    # Component to add things into replay.
    adder = adders.SequenceAdder(
        client=reverb.Client(address),
        period=sequence_period,
        sequence_length=sequence_length,
    )

    # The dataset object to learn from.
    # We don't use datasets.make_reverb_dataset() here, to avoid interleaving
    # and prefetching, which don't work well with the can_sample() check on
    # update.
    dataset = reverb.ReplayDataset.from_table_signature(
        server_address=address,
        table=adders.DEFAULT_PRIORITY_TABLE,
        max_in_flight_samples_per_worker=1,
        sequence_length=sequence_length,
        emit_timesteps=False)
    dataset = dataset.batch(batch_size, drop_remainder=True)

    optimizer = optax.chain(
        optax.clip_by_global_norm(max_gradient_norm),
        optax.adam(learning_rate),
    )

    self._learner = learning.IMPALALearner(
        obs_spec=environment_spec.observations,
        unroll_fn=unroll_fn,
        initial_state_fn=initial_state_fn,
        iterator=dataset.as_numpy_iterator(),
        rng=hk.PRNGSequence(seed),
        counter=counter,
        logger=logger,
        optimizer=optimizer,
        discount=discount,
        entropy_cost=entropy_cost,
        baseline_cost=baseline_cost,
        max_abs_reward=max_abs_reward,
    )

    variable_client = variable_utils.VariableClient(self._learner, key='policy')
    self._actor = acting.IMPALAActor(
        forward_fn=jax.jit(
            hk.without_apply_rng(
                hk.transform(forward_fn, apply_rng=True)).apply,
            backend='cpu'),
        initial_state_fn=initial_state_fn,
        rng=hk.PRNGSequence(seed),
        adder=adder,
        variable_client=variable_client,
    )
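  # --- Hedged sketch (the rest of the agent class is not shown above) --------
  # The queue/can_sample wiring only pays off together with an update method
  # that drains the queue. One plausible shape for such a method: run learner
  # steps while a full batch can be sampled, then refresh the actor's
  # parameters. This is a sketch of the pattern, not necessarily the exact
  # Acme implementation.
  def update(self):
    should_update_actor = False
    # Run learner (gradient) steps while a full batch is available.
    while self._can_sample():
      self._learner.step()
      should_update_actor = True
    if should_update_actor:
      # Pull the freshly learned policy parameters into the actor.
      self._actor.update(wait=True)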
forward_fn = jax.jit(
    hk.transform_with_state(lambda x: DurationModel(is_training=False)(x)).apply)


def predict_duration(params, aux, rng, x: DurationInput):
  d, _ = forward_fn(params, aux, rng, x)
  return d, x.durations


val_loss_fn = jax.jit(partial(loss_fn, is_training=False))
loss_vag = jax.value_and_grad(loss_fn, has_aux=True)

optimizer = optax.chain(
    optax.clip_by_global_norm(FLAGS.max_grad_norm),
    optax.adamw(FLAGS.duration_learning_rate, weight_decay=FLAGS.weight_decay)
)


@jax.jit
def update(params, aux, rng, optim_state, inputs: DurationInput):
  rng, new_rng = jax.random.split(rng)
  (loss, new_aux), grads = loss_vag(params, aux, rng, inputs)
  updates, new_optim_state = optimizer.update(grads, optim_state, params)
  new_params = optax.apply_updates(params, updates)
  return loss, (new_params, new_aux, new_rng, new_optim_state)


def initial_state(batch):
  rng = jax.random.PRNGKey(42)
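# --- Hedged sketch (self-contained toy analogue of the snippet above) --------
# DurationModel, loss_fn and FLAGS are defined elsewhere in the original
# project, so this stand-in shows the same wiring end to end: a haiku module
# transformed with state, an adamw optimizer with global-norm clipping, and a
# jitted update step that returns the new training carry. All names here are
# illustrative.
import haiku as hk
import jax
import jax.numpy as jnp
import optax


def _toy_net(x):
  return hk.nets.MLP([32, 1])(x)


toy_forward = hk.transform_with_state(_toy_net)
toy_optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(1e-3, weight_decay=1e-4))


def toy_loss(params, aux, rng, x, y):
  pred, new_aux = toy_forward.apply(params, aux, rng, x)
  return jnp.mean((pred - y) ** 2), new_aux


toy_loss_vag = jax.value_and_grad(toy_loss, has_aux=True)


@jax.jit
def toy_update(params, aux, rng, optim_state, x, y):
  rng, new_rng = jax.random.split(rng)
  (loss, new_aux), grads = toy_loss_vag(params, aux, rng, x, y)
  updates, new_optim_state = toy_optimizer.update(grads, optim_state, params)
  new_params = optax.apply_updates(params, updates)
  return loss, (new_params, new_aux, new_rng, new_optim_state)


# Usage: initialize once, then step repeatedly with the returned carry.
rng = jax.random.PRNGKey(0)
x, y = jnp.zeros((8, 4)), jnp.zeros((8, 1))
params, aux = toy_forward.init(rng, x)
optim_state = toy_optimizer.init(params)
loss, (params, aux, rng, optim_state) = toy_update(
    params, aux, rng, optim_state, x, y)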
def train_and_evaluate(config, workdir, vocab_filepath):
  """Runs a training and evaluation loop.

  Args:
    config: Model and training configuration.
    workdir: Working directory for checkpoints and TensorBoard summaries. If
      this contains a checkpoint, training will be resumed from the latest
      checkpoint.
    vocab_filepath: Absolute path to SentencePiece vocab model.

  Raises:
    ValueError: If training or eval batch sizes won't fit number of hosts and
      devices, or config is underspecified.
  """
  # Update config before config validation.
  with config.unlocked():
    # Numeric floating point type to use for model computations.
    config.dtype = jnp.float32

  train_utils.validate_config(config)

  per_host_train_batch_size = config.train_batch_size // jax.process_count()
  per_host_eval_batch_size = config.eval_batch_size // jax.process_count()

  if jax.process_index() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(workdir, "train"))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(workdir, "eval"))
  else:
    train_summary_writer = None
    eval_summary_writer = None

  tokenizer = spm.SentencePieceProcessor()
  tokenizer.Load(vocab_filepath)

  ds_info = tfds.builder(config.dataset_name).info
  num_train_examples = ds_info.splits[tfds.Split.TRAIN].num_examples

  num_train_steps = int(num_train_examples * config.num_train_epochs //
                        config.train_batch_size)
  num_warmup_steps = int(config.warmup_proportion * num_train_steps)
  # Round up evaluation frequency to the nearest multiple of 10.
  eval_frequency = int(
      math.ceil(config.eval_proportion * num_train_steps / 10)) * 10

  # STSB is a regression task. COPA and ReCoRD are treated as scalar/regression
  # tasks during training.
  is_regression_task = (
      config.dataset_name == "glue/stsb" or
      config.dataset_name == "super_glue/copa" or
      config.dataset_name == "super_glue/record")
  if is_regression_task:
    num_classes = 1
  else:
    num_classes = ds_info.features["label"].num_classes

  with config.unlocked():
    config.vocab_size = tokenizer.GetPieceSize()
    config.pad_id = tokenizer.pad_id()
  config = ml_collections.FrozenConfigDict(config)

  model = models.SequenceClassificationModel(config, num_classes)
  rng = random.PRNGKey(config.seed)
  rng, init_rng = random.split(rng)
  params = _init_params(model, init_rng, config)

  learning_rate_fn = train_utils.create_learning_rate_scheduler(
      factors="constant * linear_warmup * linear_decay",
      base_learning_rate=config.learning_rate,
      warmup_steps=num_warmup_steps,
      decay_steps=num_train_steps - num_warmup_steps,
  )

  tx = optax.adamw(
      learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.01)
  if config.clipped_grad_norm:
    tx = optax.chain(optax.clip_by_global_norm(config.clipped_grad_norm), tx)

  # jit state creation to ensure arrays are created on the same device as the
  # input (i.e. CPU).
  state_cpu = jax.jit(
      functools.partial(
          FlaxTrainState.create, apply_fn=model.apply, params=params, tx=tx))()
  # We access model params only via state.params.
  del params

  if config.num_experts > 1:
    sharded_match_fn = core_utils.match_fn(r".*expert.*")
    not_sharded_match_fn = lambda name: not sharded_match_fn(name)
  else:
    sharded_match_fn = None
    not_sharded_match_fn = lambda name: True

  state, start_step = _restore_state_from_checkpoint(workdir, state_cpu,
                                                     sharded_match_fn,
                                                     not_sharded_match_fn,
                                                     config)

  if is_regression_task:
    scoring_fn = lambda y: y[Ellipsis, 0]
  else:
    scoring_fn = lambda y: y.argmax(-1)
  compute_stats = functools.partial(
      _compute_stats, model=model, scoring_fn=scoring_fn)

  classification_inputs = functools.partial(
      input_pipeline.classification_inputs,
      dataset_name=config.dataset_name,
      max_seq_length=config.max_seq_length,
      tokenizer=tokenizer)
  train_ds = classification_inputs(
      split=tfds.Split.TRAIN,
      batch_size=per_host_train_batch_size,
      training=True)
  train_iter = iter(train_ds)

  if config.dataset_name == "glue/mnli":
    # MNLI contains two validation and test datasets.
    split_suffixes = ["_matched", "_mismatched"]
  else:
    split_suffixes = [""]

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  rngs = random.split(rng, jax.local_device_count())

  loss_and_metrics_fn = functools.partial(
      _compute_loss_and_metrics,
      model=model,
      is_experts_model=config.num_experts > 1,
      auxiliary_loss_factor=config.auxiliary_loss_factor,
      router_z_loss_factor=config.router_z_loss_factor)
  train_step = functools.partial(
      train_utils.pmap_train_step,
      loss_and_metrics_fn=loss_and_metrics_fn,
      axis_name="batch",
      sharded_match_fn=sharded_match_fn,
      gradient_accum_steps=config.gradient_accum_steps)
  p_train_step = jax.pmap(train_step, axis_name="batch")
  p_eval_step = jax.pmap(compute_stats, axis_name="batch")
  eval_metrics_fn = _create_eval_metrics_fn(config.dataset_name)

  train_stats = []
  logging.info("Starting training loop.")
  logging.info("====================")

  for step in range(start_step, num_train_steps):
    with jax.profiler.StepTraceContext("train", step_num=step):
      train_batch = next(train_iter)
      train_batch = common_utils.shard(train_batch)

      state, train_step_stats, rngs = p_train_step(
          state, train_batch, rng=rngs)

      train_stats.append(train_step_stats)

    if ((step > 0 and config.save_checkpoints_steps and
         step % config.save_checkpoints_steps == 0) or
        step == num_train_steps - 1):
      # We allow all hosts to potentially save checkpoints because some model
      # parameters are sharded across devices. Parameters replicated across
      # devices (i.e. not sharded) will only be checkpointed by host 0.
      unreplicated_train_state = jax.tree_map(
          np.array,
          core_utils.tree_unreplicate_by_name(state, not_sharded_match_fn))
      checkpoints.save_checkpoint(
          workdir,
          unreplicated_train_state,
          sharded_match_fn,
          step,
          keep=config.checkpoints_to_keep)
      del unreplicated_train_state  # Only used for checkpointing.

    # Periodic metric handling.
    if step % eval_frequency != 0 and step < num_train_steps - 1:
      continue

    logging.info("Gathering training metrics at step: %d", step)
    train_metrics = train_utils.collect_metrics(train_stats)
    train_summary = train_utils.compute_classification_metrics(
        train_metrics, is_regression_task)
    train_summary["learning_rate"] = learning_rate_fn(step)

    if jax.process_index() == 0:
      assert train_summary_writer
      for key, val in train_summary.items():
        train_summary_writer.scalar(key, val, step)
      train_summary_writer.flush()
    # Reset metric accumulation for the next training evaluation cycle.
    train_stats = []

    logging.info("Gathering validation metrics at step: %d", step)

    for split_suffix in split_suffixes:
      eval_ds = classification_inputs(
          split=tfds.Split.VALIDATION + split_suffix,
          batch_size=per_host_eval_batch_size,
          training=False)

      eval_stats = []
      for _, eval_batch in zip(range(config.max_num_eval_steps), eval_ds):
        eval_stats.append(_evaluate(p_eval_step, state.params, eval_batch))
      eval_metrics = {}
      for k in eval_stats[0]:  # All batches of output stats are the same size.
        eval_metrics[k] = np.concatenate([stat[k] for stat in eval_stats],
                                         axis=0)
      eval_summary = eval_metrics_fn(eval_metrics)

      if jax.process_index() == 0:
        assert eval_summary_writer
        for key, val in eval_summary.items():
          eval_summary_writer.scalar(f"{key}{split_suffix}", val, step)
        eval_summary_writer.flush()
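# --- Hedged sketch (train_utils.create_learning_rate_scheduler is not shown) -
# One common reading of the factor string "constant * linear_warmup *
# linear_decay" used above: the base learning rate ramps up linearly over
# warmup_steps and then decays linearly over decay_steps. The project's exact
# implementation may differ; this sketch only makes the learning_rate_fn(step)
# values logged above concrete. All names here are illustrative.
import jax.numpy as jnp


def make_schedule(base_learning_rate, warmup_steps, decay_steps):
  def schedule(step):
    step = jnp.asarray(step, jnp.float32)
    warmup = jnp.minimum(1.0, step / jnp.maximum(1.0, warmup_steps))
    decay = jnp.clip(
        1.0 - (step - warmup_steps) / jnp.maximum(1.0, decay_steps), 0.0, 1.0)
    return base_learning_rate * warmup * decay
  return schedule


# Usage mirroring the call above (illustrative values).
lr_fn = make_schedule(base_learning_rate=1e-4, warmup_steps=1000,
                      decay_steps=9000)
print(lr_fn(0), lr_fn(1000), lr_fn(10000))  # 0.0 -> peak -> back to 0.0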