def _postprocess_helper_tf(self, obs, next_obs, actions):
    with (
        tf.GradientTape() if self.framework != "tf" else NullContextManager()
    ) as tape:
        # Push both observations through feature net to get both phis.
        phis, _ = self.model._curiosity_feature_net(
            {SampleBatch.OBS: tf.concat([obs, next_obs], axis=0)}
        )
        phi, next_phi = tf.split(phis, 2)

        # Predict next phi with forward model.
        predicted_next_phi = self.model._curiosity_forward_fcnet(
            tf.concat([phi, tf_one_hot(actions, self.action_space)], axis=-1)
        )

        # Forward loss term (predicted phi', given phi and action vs
        # actually observed phi').
        forward_l2_norm_squared = 0.5 * tf.reduce_sum(
            tf.square(predicted_next_phi - next_phi), axis=-1
        )
        forward_loss = tf.reduce_mean(forward_l2_norm_squared)

        # Inverse loss term (predicted action that led from phi to phi' vs
        # actual action taken).
        phi_cat_next_phi = tf.concat([phi, next_phi], axis=-1)
        dist_inputs = self.model._curiosity_inverse_fcnet(phi_cat_next_phi)
        action_dist = (
            Categorical(dist_inputs, self.model)
            if isinstance(self.action_space, Discrete)
            else MultiCategorical(dist_inputs, self.model, self.action_space.nvec)
        )

        # Neg log(p); p=probability of observed action given the inverse-NN
        # predicted action distribution.
        inverse_loss = -action_dist.logp(tf.convert_to_tensor(actions))
        inverse_loss = tf.reduce_mean(inverse_loss)

        # Calculate the ICM loss.
        loss = (1.0 - self.beta) * inverse_loss + self.beta * forward_loss

    # Step the optimizer.
    if self.framework != "tf":
        grads = tape.gradient(loss, self._optimizer_var_list)
        grads_and_vars = [
            (g, v)
            for g, v in zip(grads, self._optimizer_var_list)
            if g is not None
        ]
        update_op = self._optimizer.apply_gradients(grads_and_vars)
    else:
        update_op = self._optimizer.minimize(
            loss, var_list=self._optimizer_var_list
        )

    # Return the squared l2 norm and the optimizer update op.
    return forward_l2_norm_squared, update_op
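
# A minimal sketch (not part of the snippet above) of how the returned squared
# forward-model error could be folded back into a sample batch as the intrinsic
# ("curiosity") reward. The method name `_postprocess_tf`, its signature, and
# the `self.eta` scaling factor are assumptions used only for illustration.
def _postprocess_tf(self, policy, sample_batch):
    forward_l2_norm_squared, _ = self._postprocess_helper_tf(
        sample_batch[SampleBatch.OBS],
        sample_batch[SampleBatch.NEXT_OBS],
        sample_batch[SampleBatch.ACTIONS],
    )
    # Scale the per-timestep forward-model error and add it on top of the
    # extrinsic environment reward.
    sample_batch[SampleBatch.REWARDS] = (
        sample_batch[SampleBatch.REWARDS] + self.eta * forward_l2_norm_squared
    )
    return sample_batch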
def _worker(shard_idx, model, sample_batch, device):
    torch.set_grad_enabled(grad_enabled)
    try:
        with (
            NullContextManager()
            if device.type == "cpu"
            else torch.cuda.device(device)
        ):
            loss_out = force_list(
                self.loss(model, self.dist_class, sample_batch)
            )

            # Call Model's custom-loss with Policy loss outputs and
            # train_batch.
            loss_out = model.custom_loss(loss_out, sample_batch)

            assert len(loss_out) == len(self._optimizers)

            # Loop through all optimizers.
            grad_info = {"allreduce_latency": 0.0}

            parameters = list(model.parameters())
            all_grads = [None for _ in range(len(parameters))]
            for opt_idx, opt in enumerate(self._optimizers):
                # Erase gradients in all vars of the tower that this
                # optimizer would affect.
                param_indices = self.multi_gpu_param_groups[opt_idx]
                for param_idx, param in enumerate(parameters):
                    if param_idx in param_indices and param.grad is not None:
                        param.grad.data.zero_()

                # Recompute gradients of loss over all variables.
                loss_out[opt_idx].backward(retain_graph=True)
                grad_info.update(
                    self.extra_grad_process(opt, loss_out[opt_idx])
                )

                grads = []
                # Note that return values are just references;
                # calling zero_grad would modify the values.
                for param_idx, param in enumerate(parameters):
                    if param_idx in param_indices:
                        if param.grad is not None:
                            grads.append(param.grad)
                        all_grads[param_idx] = param.grad

                if self.distributed_world_size:
                    start = time.time()
                    if torch.cuda.is_available():
                        # Sadly, allreduce_coalesced does not work with
                        # CUDA yet.
                        for g in grads:
                            torch.distributed.all_reduce(
                                g, op=torch.distributed.ReduceOp.SUM
                            )
                    else:
                        torch.distributed.all_reduce_coalesced(
                            grads, op=torch.distributed.ReduceOp.SUM
                        )

                    # Divide the summed gradients by the world size to get
                    # the mean gradient across all distributed workers.
                    for param_group in opt.param_groups:
                        for p in param_group["params"]:
                            if p.grad is not None:
                                p.grad /= self.distributed_world_size

                    grad_info["allreduce_latency"] += time.time() - start

        # Store this tower's gradients and stats under its shard index.
        with lock:
            results[shard_idx] = (all_grads, grad_info)
    except Exception as e:
        import traceback

        with lock:
            results[shard_idx] = (
                ValueError(
                    e.args[0]
                    + "\n traceback"
                    + traceback.format_exc()
                    + "\n"
                    + "In tower {} on device {}".format(shard_idx, device)
                ),
                e,
            )
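
# A minimal driver sketch: this code would live in the same method that defines
# _worker above, since _worker closes over `grad_enabled`, `lock`, `results`,
# and `self`. The names `self.model_gpu_towers`, `sample_batches`, and
# `self.devices` are assumptions used only for illustration.
import threading

lock = threading.Lock()
results = {}
grad_enabled = torch.is_grad_enabled()

threads = [
    threading.Thread(target=_worker, args=(shard_idx, model, shard, device))
    for shard_idx, (model, shard, device) in enumerate(
        zip(self.model_gpu_towers, sample_batches, self.devices)
    )
]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Surface the first exception captured inside any tower.
for shard_idx in range(len(threads)):
    output = results[shard_idx]
    if isinstance(output[0], ValueError):
        raise output[0] from output[1]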
def _no_grad_context(self):
    """Returns a no-grad context for torch; a no-op context otherwise."""
    if self.framework == "torch":
        return torch.no_grad()
    return NullContextManager()
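
# A small usage sketch (hypothetical helper, not part of the source above):
# wrapping inference in _no_grad_context() disables gradient tracking under
# torch and is a no-op for the other frameworks.
def _compute_dist_inputs_sketch(self, model, input_dict):
    with self._no_grad_context():
        # Pure inference; no autograd graph is built for torch here.
        dist_inputs, state_out = model(input_dict, [], None)
    return dist_inputs, state_out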
def context(self) -> contextlib.AbstractContextManager:
    """Returns a contextmanager for the current forward pass."""
    return NullContextManager()
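
# For reference, a no-op context manager like the NullContextManager used in
# the snippets above can be sketched in a few lines (contextlib.nullcontext is
# the standard-library equivalent on Python 3.7+):
class NullContextManager:
    """Context manager that does nothing on enter and exit."""

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_value, tb):
        # Returning False (or None) means exceptions are not suppressed.
        return False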