def sgd_step(
    state: TrainingState, sample: reverb.ReplaySample
) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
  """Runs one SGD step and returns the new state plus metrics for logging."""
  # Differentiate the loss w.r.t. the current parameters.
  loss_value, grads = jax.value_and_grad(loss_fn)(state.params, sample)
  # All-reduce gradients across pmap replicas so every device applies the
  # same parameter update.
  grads = jax.lax.pmean(grads, _PMAP_AXIS_NAME)
  # Transform the gradients and apply them.
  deltas, next_opt_state = optimizer.update(grads, state.opt_state)
  next_params = optax.apply_updates(state.params, deltas)
  next_state = TrainingState(params=next_params, opt_state=next_opt_state)
  return next_state, {
      'loss': loss_value,
      'param_norm': optax.global_norm(next_params),
      'param_updates_norm': optax.global_norm(deltas),
  }
def model_update_minibatch(
    carry: Tuple[networks_lib.Params, optax.OptState],
    minibatch: Batch,
) -> Tuple[Tuple[networks_lib.Params, optax.OptState], Dict[
    str, jnp.ndarray]]:
  """Runs the parameter update for one minibatch."""
  params, opt_state = carry
  # Standardize advantages within the minibatch (zero mean, unit std).
  # The epsilon guards against a degenerate minibatch with zero variance.
  adv = minibatch.advantages
  adv = (adv - jnp.mean(adv, axis=0)) / (jnp.std(adv, axis=0) + 1e-8)
  gradients, metrics = grad_fn(params, minibatch.observations,
                               minibatch.actions,
                               minibatch.behavior_log_probs,
                               minibatch.target_values, adv,
                               minibatch.behavior_values)
  # Optimizer step.
  deltas, opt_state = optimizer.update(gradients, opt_state)
  params = optax.apply_updates(params, deltas)
  metrics['norm_grad'] = optax.global_norm(gradients)
  metrics['norm_updates'] = optax.global_norm(deltas)
  return (params, opt_state), metrics
def _update_func(
    self,
    params: hk.Params,
    ema_params: hk.Params,
    network_state: hk.State,
    ema_network_state: hk.State,
    opt_state: optax.OptState,
    global_step: jnp.ndarray,
    rng: jnp.ndarray,
    batch: datasets.Batch,
) -> Tuple[hk.Params, hk.Params, hk.State, hk.State, optax.OptState,
           losses.LogsDict]:
  """Updates parameters and their EMA copies, returning logging stats.

  Args:
    params: current network parameters.
    ema_params: exponential-moving-average copy of the parameters.
    network_state: current network state.
    ema_network_state: EMA copy of the network state.
    opt_state: optimizer state.
    global_step: current training step (drives LR and EMA schedules).
    rng: PRNG key for this update.
    batch: batch of (padded) graph data.

  Returns:
    Tuple of (params, ema_params, network_state, ema_network_state,
    opt_state, stats).
  """
  grad_fn = jax.value_and_grad(self._loss, has_aux=True)
  (_, (stats, new_network_state)), grads = grad_fn(
      params, ema_params, network_state, ema_network_state, rng, batch)
  learning_rate = self._get_learning_rate(global_step)
  _, opt_apply = self._optimizer(learning_rate)
  # Average gradients across pmap replicas before the optimizer update.
  grad = jax.lax.pmean(grads, axis_name='i')
  updates, opt_state = opt_apply(grad, opt_state, params)
  params = optax.apply_updates(params, updates)

  # Stats and logging.
  param_norm = optax.global_norm(params)
  grad_norm = optax.global_norm(grad)
  ema_rate = schedules.ema_decay_schedule(
      step=global_step, **self.config.eval.ema_annealing_schedule)
  # Padding-aware averages: subtract the padding introduced by jraph so the
  # logged node/edge counts reflect real graphs only.
  num_non_padded_nodes = (
      batch.graph.n_node.sum() -
      jraph.get_number_of_padding_with_graphs_nodes(batch.graph))
  num_non_padded_edges = (
      batch.graph.n_edge.sum() -
      jraph.get_number_of_padding_with_graphs_edges(batch.graph))
  num_non_padded_graphs = (
      batch.graph.n_node.shape[0] -
      jraph.get_number_of_padding_with_graphs_graphs(batch.graph))
  avg_num_nodes = num_non_padded_nodes / num_non_padded_graphs
  avg_num_edges = num_non_padded_edges / num_non_padded_graphs
  stats.update(
      dict(
          global_step=global_step,
          grad_norm=grad_norm,
          param_norm=param_norm,
          learning_rate=learning_rate,
          ema_rate=ema_rate,
          avg_num_nodes=avg_num_nodes,
          avg_num_edges=avg_num_edges,
      ))
  ema_fn = (
      lambda x, y:  # pylint:disable=g-long-lambda
      schedules.apply_ema_decay(x, y, ema_rate))
  # jax.tree_multimap is a deprecated alias (removed in recent JAX);
  # jax.tree_map accepts multiple trees and is used elsewhere in this file.
  ema_params = jax.tree_map(ema_fn, ema_params, params)
  ema_network_state = jax.tree_map(
      ema_fn,
      ema_network_state,
      network_state,
  )
  return (params, ema_params, new_network_state, ema_network_state,
          opt_state, stats)
def _step(self, particles, optimizer_state, params):
    """Updates particles in the direction given by self.gradient.

    Arguments:
        particles: jnp.ndarray of shape (n, d).
        optimizer_state: optax optimizer state for the particles.
        params: can be anything, e.g. inducing particles in the case of
            SVGD, deep NN params for learned f, or None.

    Returns:
        particles (updated), optimizer_state (updated), and grad_aux, a
        dict containing auxiliary data.
    """
    grads, grad_aux = self.gradient(params, particles, aux=True)
    updated_grads, optimizer_state = self.opt.update(
        grads, optimizer_state, particles)
    particles = optax.apply_updates(particles, updated_grads)
    # Log gradient norms before and after the optimizer transformation.
    grad_aux.update({
        "global_grad_norm": optax.global_norm(grads),
        "global_grad_norm_post_update": optax.global_norm(updated_grads),
    })
    return particles, optimizer_state, grad_aux
def loss_fn(self,
            params,
            dlogp: np.ndarray,
            key: np.ndarray,
            particles: np.ndarray,
            dropout: bool = False):
    """Stein-discrepancy-style loss for the learned vector field.

    Arguments:
        params: neural net parameters.
        dlogp: gradient grad(log p)(x), shaped (n, d).
        key: random PRNGKey.
        particles: array of shape (n, d).
        dropout: whether to use dropout in the gradient network.

    Returns:
        (loss, aux) where loss is a scalar and aux is a dict of
        per-batch-averaged diagnostics.
    """
    n, d = particles.shape
    v = self.get_field(particles, params, dropout=dropout)
    if dropout:
        f = utils.negative(v)
    else:
        # Without dropout the field ignores the key; keep the same
        # (x, key) calling convention so the code below is uniform.
        def f(x, dummy_key):
            return -v(x)

    # stein discrepancy
    def h(x, dlogp_x, key):
        zkey, fkey = random.split(key)
        # Single-probe Hutchinson-style estimate of div f(x):
        # z^T (df/dx) z with z ~ N(0, I); cf. the exact-trace variant
        # kept commented out below.
        z = random.normal(zkey, (d, ))
        zdf = grad(lambda _x: np.vdot(z, f(_x, fkey)))
        div_f = np.vdot(zdf(x), z)
        #div_f = np.trace(jacfwd(f)(x, fkey))
        sd = np.vdot(f(x, fkey), dlogp_x) + div_f
        l2 = np.vdot(f(x, fkey), f(x, fkey))
        aux = {
            "sd": sd,
            "l2": l2,
        }
        # Negate sd (we maximize the discrepancy) and L2-regularize.
        return -sd + l2 * self.lambda_reg, aux

    keys = random.split(key, n)
    loss, aux = vmap(h)(particles, dlogp, keys)
    loss = loss.mean()
    aux = {k: v.mean() for k, v in aux.items()}
    # Norms of the mean field vs. the mean score — diagnostics only.
    fnorm = optax.global_norm(jnp.mean(vmap(f)(particles, keys), axis=0))
    pnorm = optax.global_norm(jnp.mean(dlogp, axis=0))
    aux.update({
        "loss": loss,
        "l1_diff": fnorm - pnorm,
        "l1_ratio": fnorm / pnorm
    })
    # # add L1 term
    # if self.l1_weight:
    #     loss = loss + self.l1_weight * np.abs(jnp.mean(vmap(f)(particles) - dlogp))
    return loss, aux
def _create_loss_metrics(
    loss_has_aux: bool,
    loss_result: Union[jnp.ndarray, Tuple[jnp.ndarray, loggers.LoggingData]],
    gradients: jnp.ndarray,
):
  """Creates loss metrics for logging.

  Args:
    loss_has_aux: whether loss_fn returned (loss, metrics) or a bare loss.
    loss_result: the output of loss_fn, validated against loss_has_aux.
    gradients: loss gradients, used only for the gradient-norm metric.

  Returns:
    A metrics dict containing at least 'loss' and 'gradient_norm'.

  Raises:
    ValueError: if loss_result does not have the shape loss_has_aux implies.
  """
  # Validate input.
  if loss_has_aux and not (len(loss_result) == 2 and isinstance(
      loss_result[0], jnp.ndarray) and isinstance(loss_result[1], dict)):
    raise ValueError(
        'Could not parse loss value and metrics from loss_fn\'s '
        'output. Since loss_has_aux is enabled, loss_fn must '
        'return loss_value and auxiliary metrics.')

  if not loss_has_aux and not isinstance(loss_result, jnp.ndarray):
    # Report the offending type, not its repr: the old message printed the
    # whole value, which was misleading and noisy for large arrays.
    raise ValueError(
        f'Loss returns type {type(loss_result)}. However, it should '
        'return a jnp.ndarray, given that loss_has_aux = False.')

  # Maybe unpack loss result. Copy the aux dict so the caller's metrics
  # object is not mutated in place when we add entries below.
  if loss_has_aux:
    loss, metrics = loss_result
    metrics = dict(metrics)
  else:
    loss = loss_result
    metrics = {}

  # Complete metrics dict and return it.
  metrics['loss'] = loss
  metrics['gradient_norm'] = optax.global_norm(gradients)
  return metrics
def sgd_step(
    state: TrainingState,
    transitions: types.Transition,
) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
  """Applies one optimizer step to the policy parameters."""
  # Differentiate w.r.t. the policy parameters (positional argument 1).
  grad_fn = jax.value_and_grad(loss_fn, argnums=1)

  # Split the state key so each step consumes fresh randomness while the
  # carried key stays usable for the next step.
  next_key, loss_key = jax.random.split(state.key)
  loss_value, grads = grad_fn(network.apply, state.policy_params, loss_key,
                              transitions)

  deltas, next_opt_state = optimizer.update(grads, state.optimizer_state)
  next_policy_params = optax.apply_updates(state.policy_params, deltas)

  next_state = TrainingState(
      optimizer_state=next_opt_state,
      policy_params=next_policy_params,
      key=next_key,
      steps=state.steps + 1,
  )
  return next_state, {
      'loss': loss_value,
      'gradient_norm': optax.global_norm(grads)
  }
def _step(self, key, params, optimizer_state, dlogp, val_dlogp, particles,
          val_particles):
    """Update parameters and compute validation loss.

    Args:
        dlogp: array of shape (n_train, d).
        val_dlogp: array of shape (n_validation, d).
    """
    (loss, train_aux), raw_grads = value_and_grad(
        self.loss_fn, has_aux=True)(params, dlogp, key, particles,
                                    dropout=self.dropout)
    updates, optimizer_state = self.opt.update(raw_grads, optimizer_state,
                                               params)
    params = optax.apply_updates(params, updates)
    # Evaluate (but do not differentiate) the loss on the validation split,
    # with dropout disabled.
    _, val_aux = self.loss_fn(params, val_dlogp, key, val_particles,
                              dropout=False)
    auxdata = dict(train_aux)
    auxdata.update({"val_" + k: v for k, v in val_aux.items()})
    # NOTE: this is the norm of the optimizer-transformed updates (the
    # original reassigned `grads` before taking the norm).
    auxdata.update({
        "global_gradient_norm": optax.global_norm(updates),
    })
    return params, optimizer_state, auxdata
def sgd_step(
    state: TrainingState, sample: reverb.ReplaySample
) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
  """Do a step of SGD.

  Args:
    state: current training state (params, optimizer state, step count).
    sample: replay sample to compute the loss on.

  Returns:
    Tuple of (new training state, dict of metrics for logging).
  """
  # NOTE: the return annotation previously used jnp.DeviceArray, which is
  # deprecated/removed in recent JAX; jnp.ndarray matches the sibling
  # sgd_step implementations in this file.
  grad_fn = jax.value_and_grad(loss)
  loss_value, gradients = grad_fn(state.params, sample)
  updates, new_opt_state = optimizer.update(gradients, state.opt_state)
  new_params = optax.apply_updates(state.params, updates)

  steps = state.steps + 1
  new_state = TrainingState(
      params=new_params, opt_state=new_opt_state, steps=steps)

  # Compute the global norm of the gradients for logging.
  global_gradient_norm = optax.global_norm(gradients)
  fetches = {'loss': loss_value, 'gradient_norm': global_gradient_norm}

  return new_state, fetches
def _update_func(
    self,
    base_rng,
    state,
    inputs,
):
  """Computes loss and updates model parameters (and their EMA copy).

  Args:
    base_rng: base PRNG key, shared across devices.
    state: train state with params, ema_params, opt_state and step.
    inputs: batch of training inputs.

  Returns:
    Tuple of (new state, loss dict with cross-device-averaged scalars).
  """
  step = state.step
  # Fold in the device index and the step so every replica and every step
  # gets a distinct RNG stream from the shared base key.
  rng = jax.random.fold_in(base_rng, jax.lax.axis_index('batch'))
  rng = jax.random.fold_in(rng, step)
  grad_loss_fn = jax.value_and_grad(self._loss_fn, has_aux=True)
  (_, loss_dict), scaled_grads = grad_loss_fn(state.params, inputs, rng)
  # Sum the (pre-scaled) per-device gradients across the batch axis.
  grads = jax.lax.psum(scaled_grads, axis_name='batch')
  grad_norm = optax.global_norm(grads)
  loss_dict['scalars']['grad_norm'] = grad_norm

  # Compute and apply updates via our optimizer.
  learning_rate = self.learning_rate(state.step)
  loss_dict['scalars']['learning_rate'] = learning_rate
  _, opt_apply = self.optimizer(learning_rate)
  updates, new_opt_state = opt_apply(grads, state.opt_state, state.params)
  new_params = optax.apply_updates(state.params, updates)

  # Update ema params. jax.tree_multimap is a deprecated alias (removed in
  # recent JAX); jax.tree_map handles multiple trees and is already used
  # below in this function.
  ema_rate = self.config.evaluation.ema_rate
  new_ema_params = jax.tree_map(
      lambda x, y: x + (1 - ema_rate) * (y - x),
      state.ema_params,
      new_params,
  )
  new_state = state.replace(
      step=step + 1,
      params=new_params,
      ema_params=new_ema_params,
      opt_state=new_opt_state)

  # Rescale loss dict and return
  loss_dict['scalars'] = jax.tree_map(
      lambda x: jax.lax.psum(x, axis_name='batch') / jax.device_count(),
      loss_dict['scalars'],
  )
  return new_state, loss_dict
def _update_func(
    self,
    params: hk.Params,
    state: hk.State,
    opt_state: OptState,
    inputs: dataset.Batch,
    rng: jnp.ndarray,
    global_step: int,
) -> Tuple[hk.Params, hk.State, OptState, Scalars]:
  """Applies an update to parameters and returns new state.

  Args:
    params: current model parameters.
    state: current network state.
    opt_state: optimizer state.
    inputs: batch of training data.
    rng: PRNG key for the loss.
    global_step: current step, used for the learning-rate schedule.

  Returns:
    Tuple of (params, state, opt_state, scalars-to-log).
  """
  # This function computes the gradient of the first output of loss_fn and
  # passes through the other arguments unchanged.
  grad_loss_fn = jax.grad(self._loss_fn, has_aux=True)
  scaled_grads, (loss_scalars, state) = grad_loss_fn(params, state, inputs,
                                                     rng)
  grads = jax.lax.psum(scaled_grads, axis_name='i')

  # Grab the learning rate to log before performing the step.
  learning_rate = self._lr_schedule(global_step)

  # Compute and apply updates via our optimizer.
  updates, opt_state = self._optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)

  # Count parameters over all leaves of the params tree; unlike the old
  # two-level dict loop this works for any nesting depth and avoids the
  # shadow-prone loop variable `l`.
  n_params = sum(np.prod(p.shape) for p in jax.tree_leaves(params))

  # Scalars to log (note: we log the mean across all hosts/devices).
  scalars = {
      'learning_rate': learning_rate,
      'n_params (M)': float(n_params / 1e6),
      'global_gradient_norm': optax.global_norm(grads)
  }
  loss_scalars = {f'train_{k}': v for k, v in loss_scalars.items()}
  scalars.update(loss_scalars)
  scalars = jax.lax.pmean(scalars, axis_name='i')

  return params, state, opt_state, scalars
def update(updates, state, params=None):
  """Applies the inner update only when the gradients look healthy."""
  inner_state = state.inner_state
  # A step is accepted when every entry is finite AND the global norm is
  # below the skip threshold; otherwise the update is zeroed out.
  norm = optax.global_norm(updates)
  leaves = jax.tree_flatten(updates)[0]
  all_finite = jnp.all(
      jnp.array([jnp.all(jnp.isfinite(leaf)) for leaf in leaves]))
  norm_ok = norm < gradient_norm_skip_threshold

  def accept(_):
    return inner.update(updates, inner_state, params)

  def skip(_):
    # Emit zero updates and leave the inner optimizer state untouched.
    return (jax.tree_map(jnp.zeros_like, updates), inner_state)

  new_updates, new_inner_state = jax.lax.cond(
      jnp.logical_and(all_finite, norm_ok), accept, skip, operand=None)

  return new_updates, MaybeSkipGradientUpdateState(
      inner_state=new_inner_state)
def solve_dual_train(
    env: Dict[int, DualOp],
    dual_state: ConfigDict,
    opt: optax.GradientTransformation,
    inner_opt: InnerMaxStrategy,
    dual_params: Params,
    spec_type: verify_utils.SpecType,
    dual_params_types: ParamsTypes,
    logger: Callable[[int, Mapping[str, Any]], None],
    key: jnp.array,
    num_steps: int,
    affine_before_relu: bool,
    device_type=None,
    merge_problems: Optional[Dict[int, int]] = None,
    block_to_time: bool = False,
) -> ConfigDict:
  """Compute verified upper bound via functional lagrangian relaxation.

  Args:
    env: Lagrangian computations for each contributing graph node.
    dual_state: state of the dual problem; mutated in place with the step,
      key, optimizer state, losses and best-so-far parameters.
    opt: an optimizer for the outer Lagrangian parameters.
    inner_opt: inner optimization strategy for training.
    dual_params: dual parameters to be minimized via gradient-based
      optimization.
    spec_type: Specification type, adversarial or uncertainty specification.
    dual_params_types: types of inequality encoded by the corresponding
      dual_params.
    logger: logging function, called once per step with (step, stats).
    key: jax.random.PRNGKey.
    num_steps: total number of outer optimization steps.
    affine_before_relu: whether layer ordering uses the affine layer before
      the ReLU.
    device_type: string, used to clamp to a particular hardware device.
      Default None uses JAX default device placement.
    merge_problems: the key of the dictionary corresponds to the index of
      the layer to begin the merge, and the associated value corresponds to
      the number of consecutive layers to be merged with it. For example,
      `{0: 2, 2: 3}` will merge together layer 0 and 1, as well as layers
      2, 3 and 4.
    block_to_time: whether to block computations at the end of each
      iteration to account for asynchronicity dispatch when timing.

  Returns:
    dual_state: new state of the dual problem (also carries best-loss and
    best-params information accumulated during optimization).
  """
  assert device_type in (None, 'cpu', 'gpu'), 'invalid device_type'

  # create dual functions
  loss_func = dual_build.build_dual_fun(
      env=env,
      lagrangian_form=dual_params_types.lagrangian_form,
      inner_opt=inner_opt,
      merge_problems=merge_problems,
      affine_before_relu=affine_before_relu,
      spec_type=spec_type)

  value_and_grad = jax.value_and_grad(loss_func, has_aux=True)

  def grad_step(params, opt_state, key, step):
    # One outer step: loss + grad, optimizer transform, parameter update.
    (loss_val, stats), g = value_and_grad(params, key, step)
    updates, new_opt_state = opt.update(g, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss_val, stats

  # Some solvers (e.g. MIP) cannot be jitted and run on CPU only
  if inner_opt.jittable:
    grad_step = jax.jit(grad_step, backend=device_type)

  # Initialize the mutable dual-problem state.
  dual_state.step = 0
  dual_state.key = key
  dual_state.opt_state = opt.init(dual_params)
  dual_state.dual_params = dual_params
  dual_state.loss = 0.0
  dual_state.best_loss = jnp.inf
  dual_state.best_dual_params = dual_params

  # optimize the dual (Lagrange) parameters with a gradient-based optimizer
  while dual_state.step < num_steps:
    key_step, dual_state.key = jax.random.split(dual_state.key)
    start_time = time.time()
    dual_params, dual_state.opt_state, dual_state.loss, stats = grad_step(
        dual_state.dual_params, dual_state.opt_state, key_step,
        dual_state.step)

    # Keep iterates feasible for the dual-parameter inequality types.
    dual_params = dual_build.project_dual(dual_params, dual_params_types)
    if dual_state.loss <= dual_state.best_loss:
      dual_state.best_loss = dual_state.loss
      # store value from previous iteration as loss corresponds to those params
      dual_state.best_dual_params = dual_state.dual_params
    dual_state.dual_params = dual_params  # projected dual params

    if block_to_time:
      dual_state.loss.block_until_ready()  # asynchronous dispatch
    stats['time_per_iteration'] = time.time() - start_time
    stats['best_loss'] = dual_state.best_loss
    stats['dual_params_norm'] = optax.global_norm(dual_state.dual_params)
    logger(dual_state.step, stats)

    dual_state.step += 1

  return dual_state
def _parallel_train_step(
    optimizer,
    batched_examples,
    static_batch_metadata,
    loss_fn,
    max_global_norm=None,
    **optimizer_hyper_params,
):
  """Train the model for one step in parallel across devices.

  Args:
    optimizer: Optimizer that tracks the model and parameter state. Should
      be replicated to each device, i.e. should contain ShardedDeviceArrays
      with a leading axis (num_devices, ...) but with the same content on
      each device.
    batched_examples: A structure of NDArrays representing a batch of
      examples. Should have two leading batch dimensions:
      (num_devices, batch_size_per_device, ...)
    static_batch_metadata: Metadata about this batch, which will be shared
      across all batched examples. Each value of this results in a separate
      XLA-compiled module.
    loss_fn: Task-specific non-batched loss function to apply. Should take
      the current model (optimizer.target) and an example from
      batched_examples, and return a tuple of the current loss (as a
      scalar) and a dictionary from string names to metric values (also
      scalars, or RatioMetrics).
    max_global_norm: Maximum global norm to clip gradients to. Should be a
      scalar, which will be broadcast automatically.
    **optimizer_hyper_params: Hyperparameters to pass to the optimizer's
      `apply_gradient` function, which will be broadcast across devices
      automatically.

  Returns:
    Tuple (updated_optimizer, grads_ok, metrics, agg_grads). Metrics will
    be as returned by loss_fn, with an extra element "loss". All metrics
    will be averaged across all elements of the batch. Both optimizer and
    metrics will contain ShardedDeviceArrays that are identical across
    devices. grads_ok will be a replicated bool ndarray that is True if
    the gradients were finite. agg_grads are the cross-device-averaged
    (and possibly clipped) gradients that were applied.
  """

  def batched_loss_fn(model):
    """Apply loss function across a batch of examples."""
    loss, metrics = jax.vmap(loss_fn, (None, 0, None))(model,
                                                       batched_examples,
                                                       static_batch_metadata)
    return jnp.mean(loss), metrics

  # Compute gradients of loss, along with metrics.
  (loss, metrics), grads = jax.value_and_grad(
      batched_loss_fn, has_aux=True)(optimizer.target)
  metrics["loss"] = loss

  # Exchange average gradients and metrics across devices.
  agg_grads = jax.lax.pmean(grads, "devices")
  agg_metrics = {}
  for k, v in metrics.items():
    if isinstance(v, RatioMetric):
      # Ratio metrics are aggregated as (sum of numerators)/(sum of
      # denominators) across all devices.
      num = jax.lax.psum(jnp.sum(v.numerator), "devices")
      denom = jax.lax.psum(jnp.sum(v.denominator), "devices")
      new_value = num / denom
    else:
      # Use nanmean to aggregate bare floats.
      new_value = jnp.nanmean(jax.lax.all_gather(v, "devices"))
    agg_metrics[k] = new_value

  # Compute global norm and possibly clip.
  global_norm = optax.global_norm(agg_grads)
  agg_metrics["gradient_global_norm"] = global_norm
  if max_global_norm is not None:
    should_clip = global_norm > max_global_norm
    agg_grads = jax.tree_map(
        lambda g: jnp.where(should_clip, g * max_global_norm / global_norm,
                            g), agg_grads)
    agg_metrics["gradient_was_clipped"] = should_clip.astype("float32")

  # Check for non-finite gradients.
  grads_ok = jnp.all(
      jnp.stack(
          [jnp.all(jnp.isfinite(x)) for x in jax.tree_leaves(agg_grads)]))

  # Apply updates.
  updated_optimizer = optimizer.apply_gradient(agg_grads,
                                               **optimizer_hyper_params)

  return updated_optimizer, grads_ok, agg_metrics, agg_grads
def sgd_step(
    state: TrainingState, sample: reverb.ReplaySample
) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
  """Performs a minibatch SGD step, returning new state and metrics.

  Runs num_epochs passes over the sampled trajectories; each epoch
  shuffles the flattened batch and scans over num_minibatches minibatch
  updates.
  """

  # Extract the data.
  data = sample.data
  # TODO(sinopalnikov): replace it with namedtuple unpacking
  observations, actions, rewards, termination, extra = (
      data.observation, data.action, data.reward, data.discount,
      data.extras)
  discounts = termination * discount
  behavior_log_probs = extra['log_prob']

  def get_behavior_values(
      params: networks_lib.Params,
      observations: types.NestedArray) -> jnp.ndarray:
    """Evaluates values under the behavior policy for the full batch."""
    # Flatten the two leading (sequence, step) dims for the network apply,
    # then restore them on the value output.
    o = jax.tree_map(
        lambda x: jnp.reshape(x, [-1] + list(x.shape[2:])), observations)
    _, behavior_values = ppo_networks.network.apply(params, o)
    behavior_values = jnp.reshape(behavior_values, rewards.shape[0:2])
    return behavior_values

  behavior_values = get_behavior_values(state.params, observations)

  # Vmap over batch dimension
  batch_gae_advantages = jax.vmap(gae_advantages, in_axes=0)
  advantages, target_values = batch_gae_advantages(rewards, discounts,
                                                   behavior_values)
  # Exclude the last step - it was only used for bootstrapping.
  # The shape is [num_sequences, num_steps, ..]
  observations, actions, behavior_log_probs, behavior_values = jax.tree_map(
      lambda x: x[:, :-1],
      (observations, actions, behavior_log_probs, behavior_values))
  trajectories = Batch(
      observations=observations,
      actions=actions,
      advantages=advantages,
      behavior_log_probs=behavior_log_probs,
      target_values=target_values,
      behavior_values=behavior_values)

  # Concatenate all trajectories. Reshape from [num_sequences, num_steps,..]
  # to [num_sequences * num_steps,..]
  assert len(target_values.shape) > 1
  num_sequences = target_values.shape[0]
  num_steps = target_values.shape[1]
  batch_size = num_sequences * num_steps
  assert batch_size % num_minibatches == 0, (
      'Num minibatches must divide batch size. Got batch_size={}'
      ' num_minibatches={}.').format(batch_size, num_minibatches)
  batch = jax.tree_map(
      lambda x: x.reshape((batch_size, ) + x.shape[2:]), trajectories)

  # Compute gradients.
  grad_fn = jax.grad(loss, has_aux=True)

  def model_update_minibatch(
      carry: Tuple[networks_lib.Params, optax.OptState],
      minibatch: Batch,
  ) -> Tuple[Tuple[networks_lib.Params, optax.OptState], Dict[
      str, jnp.ndarray]]:
    """Performs model update for a single minibatch."""
    params, opt_state = carry
    # Normalize advantages at the minibatch level before using them.
    advantages = (
        (minibatch.advantages - jnp.mean(minibatch.advantages, axis=0)) /
        (jnp.std(minibatch.advantages, axis=0) + 1e-8))
    gradients, metrics = grad_fn(params, minibatch.observations,
                                 minibatch.actions,
                                 minibatch.behavior_log_probs,
                                 minibatch.target_values, advantages,
                                 minibatch.behavior_values)

    # Apply updates
    updates, opt_state = optimizer.update(gradients, opt_state)
    params = optax.apply_updates(params, updates)

    metrics['norm_grad'] = optax.global_norm(gradients)
    metrics['norm_updates'] = optax.global_norm(updates)
    return (params, opt_state), metrics

  def model_update_epoch(
      carry: Tuple[jnp.ndarray, networks_lib.Params, optax.OptState, Batch],
      unused_t: Tuple[()]
  ) -> Tuple[Tuple[jnp.ndarray, networks_lib.Params, optax.OptState, Batch],
             Dict[str, jnp.ndarray]]:
    """Performs model updates based on one epoch of data."""
    key, params, opt_state, batch = carry
    # Fresh shuffle each epoch, then split into equal-sized minibatches.
    key, subkey = jax.random.split(key)
    permutation = jax.random.permutation(subkey, batch_size)
    shuffled_batch = jax.tree_map(
        lambda x: jnp.take(x, permutation, axis=0), batch)
    minibatches = jax.tree_map(
        lambda x: jnp.reshape(x, [num_minibatches, -1] + list(x.shape[1:])),
        shuffled_batch)

    (params, opt_state), metrics = jax.lax.scan(
        model_update_minibatch, (params, opt_state), minibatches,
        length=num_minibatches)
    return (key, params, opt_state, batch), metrics

  params = state.params
  opt_state = state.opt_state

  # Repeat training for the given number of epoch, taking a random
  # permutation for every epoch.
  (key, params, opt_state, _), metrics = jax.lax.scan(
      model_update_epoch, (state.random_key, params, opt_state, batch), (),
      length=num_epochs)

  metrics = jax.tree_map(jnp.mean, metrics)
  metrics['norm_params'] = optax.global_norm(params)
  metrics['observations_mean'] = jnp.mean(
      utils.batch_concat(
          jax.tree_map(lambda x: jnp.abs(jnp.mean(x, axis=(0, 1))),
                       observations),
          num_batch_dims=0))
  metrics['observations_std'] = jnp.mean(
      utils.batch_concat(
          jax.tree_map(lambda x: jnp.std(x, axis=(0, 1)), observations),
          num_batch_dims=0))
  metrics['rewards_mean'] = jnp.mean(
      jnp.abs(jnp.mean(rewards, axis=(0, 1))))
  metrics['rewards_std'] = jnp.std(rewards, axis=(0, 1))

  new_state = TrainingState(
      params=params, opt_state=opt_state, random_key=key)
  return new_state, metrics