def f(log_probs, advantages, old_log_probs, mask):
  if reweight:  # Use new policy weights for sampled actions instead.
    mask *= jnp.exp(fastmath.stop_gradient(log_probs) - old_log_probs)
  if sampled_all_discrete:  # Actions were sampled uniformly; weight them.
    mask *= jnp.exp(old_log_probs)
  weights = jnp.minimum(awr_weights(advantages, beta), w_max)
  return -jnp.sum(log_probs * weights * mask) / jnp.sum(mask)
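# A minimal, self-contained sketch of the weighting above, assuming that
# awr_weights computes exp(advantages / beta) as in standard AWR; the
# function name awr_loss_sketch and its defaults are illustrative only,
# not the Trax API.
import jax.numpy as jnp

def awr_loss_sketch(log_probs, advantages, mask, beta=1.0, w_max=20.0):
  # Exponentiated advantages, clipped at w_max for numerical stability.
  weights = jnp.minimum(jnp.exp(advantages / beta), w_max)
  # Weighted negative log-likelihood, averaged over unmasked positions.
  return -jnp.sum(log_probs * weights * mask) / jnp.sum(mask)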
def f(preds, values, returns, actions, mask):
  advantages = jnp.squeeze(returns - stop_gradient(values), axis=-1)
  logps = self._policy_dist.log_prob(preds, actions)
  awr_loss = actor_critic.AWRLoss(beta=self._beta, w_max=self._w_max)(
      (logps, advantages, jnp.zeros_like(logps), mask))
  l2_value_loss = jnp.mean((returns - values)**2) * self._value_loss_coeff
  return awr_loss + l2_value_loss
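# The value head above is trained with plain L2 regression toward the
# returns; a standalone equivalent of the l2_value_loss term, for reference:
import jax.numpy as jnp

def l2_value_loss_sketch(values, returns, value_loss_coeff=1.0):
  return jnp.mean((returns - values) ** 2) * value_loss_coeff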
def f(dist_inputs, values, returns, dones, rewards, actions,
      old_log_probs, mask):
  """Definition of the A2C loss."""
  del old_log_probs

  # Typically we have dist_inputs of the shape float32[128,1,18]
  assert len(dist_inputs.shape) == 3, (
      f'dist_inputs.shape was {dist_inputs.shape} '
      f'but expected length of the tensor shape is 3')
  # values of the shape float32[128,1,1]
  # returns of the shape float32[128,1,1]
  assert values.shape == returns.shape, (
      f'values.shape was {values.shape} and '
      f'returns.shape was {returns.shape}')
  # actions of the shape int32[128,1] in the case of discrete actions
  # and float32[128,1,6] in the case of half-cheetah;
  # actions agree with returns/values on the first two coordinates
  assert actions.shape[0:2] == returns.shape[0:2], (
      f'actions.shape was {actions.shape} and '
      f'returns.shape was {returns.shape}')
  # and mask of the shape float32[128,1]
  assert len(mask.shape) == 2, f'mask.shape was {mask.shape}'
  # which agrees with returns/values/actions on the first two coordinates
  assert mask.shape[0:2] == returns.shape[0:2], (
      f'mask.shape was {mask.shape} and '
      f'returns.shape was {returns.shape}')

  a2c_objective = rl_layers.A2CObjective(
      dist_inputs,
      stop_gradient(values),
      returns,
      dones,
      rewards,
      actions,
      mask,
      log_prob_fun=self._policy_dist.log_prob,
      normalize_advantages=self._normalize_advantages)

  # We insist that a2c_objective is a scalar.
  assert jnp.ndim(a2c_objective) == 0, f'a2c_objective was {a2c_objective}'

  entropy_loss = rl_layers.EntropyLoss(
      dist_inputs,
      distribution=self._policy_dist,
      coeff=self._entropy_coeff,
  )
  assert jnp.ndim(entropy_loss) == 0, f'entropy_loss was {entropy_loss}'

  l2_value_loss = rl_layers.ValueLoss(
      values, returns, value_loss_coeff=self._value_loss_coeff)
  assert jnp.ndim(l2_value_loss) == 0, f'l2_value_loss was {l2_value_loss}'

  combined_loss = a2c_objective + l2_value_loss - entropy_loss
  return combined_loss
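# For reference, a hedged sketch of the standard advantage actor-critic
# objective that a layer like rl_layers.A2CObjective computes; the actual
# Trax layer may differ in details such as advantage normalization.
import jax.numpy as jnp
from jax import lax

def a2c_objective_sketch(log_probs, values, returns, mask):
  # Advantages; values carry no gradient here, so only the policy head is
  # updated by this term (the value head is trained by the L2 loss).
  advantages = jnp.squeeze(returns - lax.stop_gradient(values), axis=-1)
  # Negative advantage-weighted log-likelihood, masked and averaged.
  return -jnp.sum(log_probs * advantages * mask) / jnp.sum(mask)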
def _do_custom_gradients(self, x, weights, state, rng):
  """Calls this layer for a forward pass, but with custom gradients."""

  def _do_forward(y, weights):
    old_weights, old_state, old_rng = self.weights, self.state, self._rng
    self.weights = weights
    res = self.forward(y)
    s = self.state
    self.weights, self.state, self._rng = old_weights, old_state, old_rng
    return res, s

  def do_forward_vjp(y, weights):
    """Custom gradient (vjp) function."""
    old_weights, old_state, old_rng = self.weights, self.state, self._rng
    self.weights = weights
    output = self.forward(y)
    new_state = self.state
    self.weights, self.state, self._rng = old_weights, old_state, old_rng

    def vjpfun(grad):
      grad = grad[0]  # Ignore dummy gradient wrt state.
      res = self.backward(y, output, grad, weights, state, new_state, rng)
      return res

    return (output, new_state), vjpfun

  do_forward = fastmath.custom_grad(do_forward_vjp, _do_forward)
  output, state = do_forward(x, weights)
  # TODO(lukaszkaiser): Investigate why we need this stop_gradient.
  state = fastmath.stop_gradient(state)
  return output, state
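# The pure-JAX analogue of the fastmath.custom_grad pattern above is
# jax.custom_vjp; a minimal sketch (not the Trax implementation), where the
# backward pass is overridden to clip gradients:
import jax
import jax.numpy as jnp

@jax.custom_vjp
def clipped_square(x):
  return jnp.square(x)

def clipped_square_fwd(x):
  return jnp.square(x), x  # Save x as the residual for the backward pass.

def clipped_square_bwd(x, grad):
  # Custom vjp, analogous to self.backward above: clip the true gradient.
  return (jnp.clip(2.0 * x * grad, -1.0, 1.0),)

clipped_square.defvjp(clipped_square_fwd, clipped_square_bwd)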
def StopGradient():
  """Returns an identity layer with a stop gradient."""
  return Fn('StopGradient', lambda x: fastmath.stop_gradient(x))  # pylint: disable=unnecessary-lambda
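# Example use of the StopGradient layer with the usual Trax combinators
# (tl.Serial, tl.Dense); gradients do not flow back through it, so the
# first Dense below is effectively frozen with respect to this loss:
frozen_then_trainable = tl.Serial(
    tl.Dense(64),
    StopGradient(),  # Blocks gradients from reaching the first Dense.
    tl.Dense(10),
)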
def f(dist_inputs, values, returns, dones, rewards, actions,
      old_log_probs, mask):
  """Definition of the Proximal Policy Optimization loss."""
  del mask  # TODO(lukaszkaiser): make PPO work with Transformer.

  # We have dist_inputs of the shape float32[128,1,18]
  assert len(dist_inputs.shape) == 3, (
      f'dist_inputs.shape was {dist_inputs.shape} '
      f'but expected length of the tensor shape is 3')
  # values of the shape float32[128,1,1]
  # returns of the shape float32[128,1,1]
  # dones of the shape int32[128,1,1]
  # rewards of the shape float32[128,1,1]
  # and old_log_probs of the shape float32[128,1]
  assert values.shape == returns.shape, (
      f'values.shape was {values.shape} and '
      f'returns.shape was {returns.shape}')
  assert values.shape == dones.shape, (
      f'values.shape was {values.shape} and '
      f'dones.shape was {dones.shape}')
  assert rewards.shape == dones.shape, (
      f'rewards.shape was {rewards.shape} and '
      f'dones.shape was {dones.shape}')
  assert returns.shape[0:2] == old_log_probs.shape, (
      f'returns.shape was {returns.shape} and '
      f'old_log_probs.shape was {old_log_probs.shape}')
  # actions is a tensor of the shape int32[128,1] in the case
  # of discrete actions and float32[128,1,6] in the case of
  # half-cheetah and other continuous actions;
  # actions agree with returns/values on the first two coordinates,
  # meaning batch and time
  assert actions.shape[0:2] == returns.shape[0:2], (
      f'actions.shape was {actions.shape} and '
      f'returns.shape was {returns.shape}')

  ppo_objective = rl_layers.PPOObjective(
      dist_inputs,
      stop_gradient(values),
      returns,
      dones,
      rewards,
      actions,
      old_log_probs,
      log_prob_fun=self._policy_dist.log_prob,
      epsilon=self._epsilon,
      normalize_advantages=self._normalize_advantages)

  # We insist that ppo_objective is a vector of shape [128,1]
  assert len(ppo_objective.shape) == 2, (
      f'ppo_objective was {ppo_objective}')
  # which agrees with returns/values/actions on the first two coordinates.
  assert ppo_objective.shape[0:2] == values.shape[0:2], (
      f'ppo_objective.shape was {ppo_objective.shape} and '
      f'values.shape was {values.shape}')

  entropy_loss = rl_layers.EntropyLoss(
      dist_inputs,
      distribution=self._policy_dist,
      coeff=self._entropy_coeff,
  )
  assert jnp.ndim(entropy_loss) == 0, f'entropy_loss was {entropy_loss}'

  l2_value_loss = rl_layers.ValueLoss(
      values, returns, value_loss_coeff=self._value_loss_coeff)
  assert jnp.ndim(l2_value_loss) == 0, f'l2_value_loss was {l2_value_loss}'

  return -ppo_objective.mean() + l2_value_loss - entropy_loss
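# For reference, a hedged sketch of the standard PPO clipped surrogate that
# a layer like rl_layers.PPOObjective computes per timestep; details such as
# advantage normalization may differ in the actual implementation.
import jax.numpy as jnp

def ppo_surrogate_sketch(new_log_probs, old_log_probs, advantages, epsilon):
  # Probability ratio between the new and old policies.
  ratios = jnp.exp(new_log_probs - old_log_probs)
  clipped = jnp.clip(ratios, 1.0 - epsilon, 1.0 + epsilon)
  # Take the pessimistic (smaller) of the two objectives, as in PPO.
  return jnp.minimum(ratios * advantages, clipped * advantages)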
def forward(self, x):
  """Executes this layer as part of a forward pass through the model.

  Args:
    x: Tensor of same shape and dtype as the input signature used to
        initialize this layer.

  Returns:
    Tensor of same shape and dtype as the input.
  """
  m1, m2, mb, w1, w2, b2 = self.weights
  if self._mode != 'predict':
    w1 = jnp.reshape(w1.T, (-1, self._d_ff))
    w2 = jnp.reshape(w2, (self._d_ff, -1))
  x_shape = x.shape
  x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.

  # Q: should we add bias and/or put relu after the low-rank m1 dot?
  mask_logits = jnp.dot(jnp.dot(x, m1), m2) + mb
  mask_logits = jnp.reshape(mask_logits, [-1, self._d1, self._d2])
  # Softmax.
  mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)
  log_mask = mask_logits - mask_logsumexp
  mask = jnp.exp(log_mask)
  # Gumbel-softmax with straight-through discretization.
  rng1, rng2 = fastmath.random.split(self.rng, 2)
  u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6,
                              1.0 - 1e-6)
  g = -jnp.log(-jnp.log(u))
  quant_mask = jnp.argmax(log_mask + g * self._temperature, axis=-1)
  if self._mode == 'train':
    # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
    quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
    quant_mask = fastmath.stop_gradient(quant_mask)
    quant_mask += mask - fastmath.stop_gradient(mask)  # straight-through
    # We will sometimes (quant_prob of the batches) use the soft-mask instead
    # of the quantized mask to improve training stability (see paper above).
    select = fastmath.random.uniform(rng2, (), jnp.float32, 0.0, 1.0)
    quant_mask = jnp.where(select < self._quant_prob, quant_mask, mask)
    quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])

  if self._mode == 'train':
    # In training, run full matmul to get benefits from the above tricks.
    mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
    relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
    res = jnp.dot(relu, w2) + b2
  elif self._mode == 'predict':
    # w1 = jnp.reshape(w1.T, (self._d1, self._d2, -1))
    # w2 = jnp.reshape(w2, (self._d1, self._d2, -1))
    # This implementation mimics inference. It's not efficient for large
    # size of joint_batch, but at inference that will be 1 most of the time.
    # Shapes:
    # quant_mask is [joint_batch, self._d1]
    # w1 is [d_model, self._d1, self._d2]
    # We'll index w1 with advanced numpy indexing: the first index ranges
    # over self._d1 times the batch size, the second is quant_mask.
    batch_size = quant_mask.shape[0]
    idx1 = jnp.array([jnp.arange(self._d1)] * batch_size)
    # Flatten indices and select from w1.
    idx1 = jnp.reshape(idx1, [-1])
    idx2 = jnp.reshape(quant_mask, [-1])
    w = w1[idx1, idx2, :]  # Now we have per-element weights with batch dim.
    w = jnp.reshape(w, [batch_size, self._d1, -1])
    mid = jnp.einsum('ai,aji->aj', x, w)
    relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
    # w2 is [self._d1, self._d2, d_model]
    v = w2[idx1, idx2, :]
    v = jnp.reshape(v, [batch_size, self._d1, -1])
    res = jnp.einsum('ai,aij->aj', relu, v) + b2
  else:
    quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
    quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])
    mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
    relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
    res = jnp.dot(relu, w2) + b2

  return jnp.reshape(res, x_shape)  # un-flatten if needed
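# The straight-through trick used above, in isolation: the forward value is
# the hard (discretized) mask, while gradients flow as if the soft mask had
# been used, since stop_gradient contributes zero to the backward pass.
import jax
import jax.numpy as jnp

def straight_through(soft, hard):
  # Forward: hard; backward: identity with respect to soft.
  return jax.lax.stop_gradient(hard) + soft - jax.lax.stop_gradient(soft)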
def forward(self, x):
  """Executes this layer as part of a forward pass through the model.

  Args:
    x: Tensor of same shape and dtype as the input signature used to
        initialize this layer.

  Returns:
    Tensor of same shape and dtype as the input.
  """
  m1, w1, w2, b2 = self.weights
  x_shape = x.shape
  x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.

  # Q: check if we need bias and/or put relu after the m1 dot?
  mask_logits = jnp.dot(x, m1)
  # Softmax.
  mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)
  log_mask = mask_logits - mask_logsumexp
  mask = jnp.exp(log_mask)
  # Gumbel-softmax with straight-through discretization.
  # TODO(lukaszkaiser, chowdhery): Extract this block and share.
  rng1, rng2 = fastmath.random.split(self.rng, 2)
  u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6,
                              1.0 - 1e-6)
  g = -jnp.log(-jnp.log(u))
  selected_experts = jnp.argmax(log_mask + g * self._temperature, axis=-1)
  if self._mode == 'train':
    # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
    quant_mask = tl.one_hot(selected_experts, self._num_experts)
    quant_mask = fastmath.stop_gradient(quant_mask)
    quant_mask += mask - fastmath.stop_gradient(mask)  # straight-through
    # We will sometimes (50% of the batches) use the soft-mask instead of
    # the quantized mask to improve training stability (see the paper above).
    # Q: is selecting 50% of batches the best? Other %? Mixed in-batch?
    select = fastmath.random.uniform(rng2, (), jnp.float32, -1.0, 1.0)
    quant_mask = jnp.where(select > 0.0, quant_mask, mask)
  else:
    quant_mask = tl.one_hot(selected_experts, self._num_experts)
  quant_mask = jnp.reshape(quant_mask, [-1, self._num_experts, 1])
  quant_mask_shape = quant_mask.shape
  batch_size = quant_mask.shape[0]

  if self._mode == 'predict' and batch_size == 1:
    # This implementation mimics inference for batch_size 1.
    start_idx = selected_experts[0] * self._n_elements_in_block
    # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block]
    w = fastmath.dynamic_slice(
        w1, [0, start_idx], [w1.shape[0], self._n_elements_in_block])
    mid = jnp.dot(x, w)
    relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
    # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model]
    v = fastmath.dynamic_slice(
        w2, [start_idx, 0], [self._n_elements_in_block, w2.shape[-1]])
    v = jnp.reshape(v, [self._n_elements_in_block, -1])
    res = jnp.dot(relu, v) + b2
  else:
    expanded_mask = jnp.broadcast_to(
        quant_mask,
        (quant_mask_shape[0], quant_mask.shape[1],
         self._n_elements_in_block))
    expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff))
    mid = jnp.dot(x, w1) * expanded_mask  # [joint_batch, d_ff]
    relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
    res = jnp.dot(relu, w2) + b2

  return jnp.reshape(res, x_shape)  # un-flatten if needed
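# fastmath.dynamic_slice above mirrors jax.lax.dynamic_slice; a small
# self-contained example of slicing one expert's block out of w1, with
# illustrative shapes (not the layer's real dimensions):
import jax.numpy as jnp
from jax import lax

w1 = jnp.ones((8, 32))                # [d_model, d_ff]
n_elements_in_block = 4
start_idx = 2 * n_elements_in_block   # e.g. expert 2 was selected
w = lax.dynamic_slice(w1, (0, start_idx),
                      (w1.shape[0], n_elements_in_block))
assert w.shape == (8, 4)              # One expert's [d_model, block] slice.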