def __init__(self, obs_space, action_space, config):
    if get_default_config:
        config = dict(get_default_config(), **config)
    self.config = config

    if before_init:
        before_init(self, obs_space, action_space, config)

    if make_model_and_action_dist:
        self.model, self.dist_class = make_model_and_action_dist(
            self, obs_space, action_space, config)
    else:
        self.dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"], torch=True)
        self.model = ModelCatalog.get_model_v2(
            obs_space, action_space, logit_dim, self.config["model"],
            framework="torch")

    TorchPolicy.__init__(self, obs_space, action_space, config,
                         self.model, loss_fn, self.dist_class)

    if after_init:
        after_init(self, obs_space, action_space, config)
def __init__(self, obs_space, action_space, config):
    if get_default_config:
        config = dict(get_default_config(), **config)
    self.config = config

    if before_init:
        before_init(self, obs_space, action_space, config)

    if make_model_and_action_dist:
        self.model, self.dist_class = make_model_and_action_dist(
            self, obs_space, action_space, config)
        # Make sure we were passed a correct Model factory.
        assert isinstance(self.model, TorchModelV2), \
            "ERROR: TorchPolicy::make_model_and_action_dist must " \
            "return a TorchModelV2 object!"
    else:
        self.dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"], framework="torch")
        self.model = ModelCatalog.get_model_v2(
            obs_space, action_space, logit_dim, self.config["model"],
            framework="torch")

    TorchPolicy.__init__(self, obs_space, action_space, config,
                         self.model, loss_fn, self.dist_class)

    if after_init:
        after_init(self, obs_space, action_space, config)
def __init__(self, observation_space, action_space, config):
    config = dict(ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, **config)
    setup_config(self, observation_space, action_space, config)

    TorchPolicy.__init__(
        self,
        observation_space,
        action_space,
        config,
        max_seq_len=config["model"]["max_seq_len"],
    )

    ValueNetworkMixin.__init__(self, config)
    EntropyCoeffSchedule.__init__(self, config["entropy_coeff"],
                                  config["entropy_coeff_schedule"])
    LearningRateSchedule.__init__(self, config["lr"],
                                  config["lr_schedule"])

    # The current KL value (as a python float).
    self.kl_coeff = self.config["kl_coeff"]
    # Constant target value.
    self.kl_target = self.config["kl_target"]

    # TODO: Don't require users to call this manually.
    self._initialize_loss_from_dummy_batch()
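# Hedged usage sketch (not from the source): instantiating the PPO policy
# above standalone, outside a Trainer. The class name `PPOTorchPolicy`, the
# import paths, and the gym spaces are assumptions for illustration only.
import gym
from ray.rllib.agents.ppo import DEFAULT_CONFIG
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy

policy = PPOTorchPolicy(
    observation_space=gym.spaces.Box(-1.0, 1.0, (4,)),
    action_space=gym.spaces.Discrete(2),
    config=DEFAULT_CONFIG.copy(),
)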
def __init__(self, obs_space, action_space, config):
    if get_default_config:
        config = dict(get_default_config(), **config)
    self.config = config

    if validate_spaces:
        validate_spaces(self, obs_space, action_space, self.config)

    if before_init:
        before_init(self, obs_space, action_space, self.config)

    # Model is customized (use default action dist class).
    if make_model:
        assert make_model_and_action_dist is None, \
            "Either `make_model` or `make_model_and_action_dist`" \
            " must be None!"
        self.model = make_model(self, obs_space, action_space, config)
        dist_class, _ = ModelCatalog.get_action_dist(
            action_space, self.config["model"], framework="torch")
    # Model and action dist class are customized.
    elif make_model_and_action_dist:
        self.model, dist_class = make_model_and_action_dist(
            self, obs_space, action_space, config)
    # Use default model and default action dist.
    else:
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"], framework="torch")
        self.model = ModelCatalog.get_model_v2(
            obs_space=obs_space,
            action_space=action_space,
            num_outputs=logit_dim,
            model_config=self.config["model"],
            framework="torch")

    # Make sure we were passed a correct Model factory.
    assert isinstance(self.model, TorchModelV2), \
        "ERROR: Generated Model must be a TorchModelV2 object!"

    TorchPolicy.__init__(
        self,
        observation_space=obs_space,
        action_space=action_space,
        config=config,
        model=self.model,
        loss=loss_fn,
        action_distribution_class=dist_class,
        action_sampler_fn=action_sampler_fn,
        action_distribution_fn=action_distribution_fn,
        max_seq_len=config["model"]["max_seq_len"],
        get_batch_divisibility_req=get_batch_divisibility_req,
    )

    if callable(training_view_requirements_fn):
        self.training_view_requirements.update(
            training_view_requirements_fn(self))

    if after_init:
        after_init(self, obs_space, action_space, config)
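# Hedged sketch of how a template __init__ like the one above gets wired up.
# `build_torch_policy` is RLlib's actual factory; `my_loss_fn` is a
# placeholder loss written against the (policy, model, dist_class,
# train_batch) signature the template expects.
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy

def my_loss_fn(policy, model, dist_class, train_batch):
    # Placeholder objective: maximize log-prob of the taken actions.
    logits, _ = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    return -action_dist.logp(train_batch[SampleBatch.ACTIONS]).mean()

MyTorchPolicy = build_torch_policy(
    name="MyTorchPolicy",
    loss_fn=my_loss_fn,
)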
def extra_action_out(self, input_dict, state_batches, model_out):
    if extra_action_out_fn:
        return extra_action_out_fn(self, input_dict, state_batches,
                                   model_out)
    else:
        return TorchPolicy.extra_action_out(self, input_dict,
                                            state_batches, model_out)
def extra_grad_info(self, train_batch):
    with torch.no_grad():
        if stats_fn:
            stats_dict = stats_fn(self, train_batch)
        else:
            stats_dict = TorchPolicy.extra_grad_info(self, train_batch)
        return convert_to_non_torch_type(stats_dict)
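# Hedged example of a `stats_fn` the override above would call. It runs
# under torch.no_grad() and its tensors are converted to numpy by
# convert_to_non_torch_type. `policy._total_loss` is a hypothetical
# attribute assumed to have been stashed by the loss function.
def my_stats_fn(policy, train_batch):
    return {
        "total_loss": policy._total_loss,  # hypothetical, set in loss_fn
        "cur_lr": policy.cur_lr,  # provided by the LearningRateSchedule mixin
    }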
def extra_compute_grad_fetches(self):
    if extra_learn_fetches_fn:
        fetches = convert_to_non_torch_type(extra_learn_fetches_fn(self))
        # Auto-add empty learner stats dict if needed.
        return dict({LEARNER_STATS_KEY: {}}, **fetches)
    else:
        return TorchPolicy.extra_compute_grad_fetches(self)
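# Hedged example of an `extra_learn_fetches_fn`: return extra values to ship
# back from each learner update. Assumes a policy that tracks `kl_coeff`, as
# the PPO __init__ shown earlier does.
def my_extra_learn_fetches_fn(policy):
    return {"kl_coeff": policy.kl_coeff}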
def optimizer(self):
    if optimizer_fn:
        optimizers = optimizer_fn(self, self.config)
    else:
        optimizers = TorchPolicy.optimizer(self)
    optimizers = force_list(optimizers)
    if getattr(self, "exploration", None):
        optimizers = self.exploration.get_exploration_optimizer(optimizers)
    return optimizers
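# Hedged example of an `optimizer_fn` for the override above. Returning a
# single optimizer is fine, since the caller wraps the result with
# force_list(). Assumes "lr" is present in the config (it is in RLlib's
# defaults).
import torch

def my_optimizer_fn(policy, config):
    # One Adam optimizer over all model parameters.
    return torch.optim.Adam(policy.model.parameters(), lr=config["lr"])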
def extra_action_out(self, input_dict, state_batches, model, action_dist):
    with torch.no_grad():
        if extra_action_out_fn:
            stats_dict = extra_action_out_fn(
                self, input_dict, state_batches, model, action_dist)
        else:
            stats_dict = TorchPolicy.extra_action_out(
                self, input_dict, state_batches, model, action_dist)
        return convert_to_non_torch_type(stats_dict)
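# Hedged example of an `extra_action_out_fn` matching the signature above:
# emit the value-function estimate alongside each sampled action so it lands
# in the sample batch (the pattern RLlib's PPO/A3C policies use). Assumes the
# model implements value_function() and has already run its forward pass.
from ray.rllib.policy.sample_batch import SampleBatch

def my_extra_action_out_fn(policy, input_dict, state_batches, model,
                           action_dist):
    return {SampleBatch.VF_PREDS: model.value_function()}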
def extra_grad_process(self, optimizer, loss):
    """Called after optimizer.zero_grad() and loss.backward() calls.

    Allows for gradient processing before optimizer.step() is called.
    E.g. for gradient clipping.
    """
    if extra_grad_process_fn:
        return extra_grad_process_fn(self, optimizer, loss)
    else:
        return TorchPolicy.extra_grad_process(self, optimizer, loss)
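# Hedged example of an `extra_grad_process_fn` for the hook above:
# global-norm gradient clipping between loss.backward() and optimizer.step().
# The max_norm value of 40.0 is an arbitrary assumption.
import torch

def my_extra_grad_process_fn(policy, optimizer, loss):
    total_norm = torch.nn.utils.clip_grad_norm_(
        policy.model.parameters(), max_norm=40.0)
    # Return a stats dict, as TorchPolicy expects from this hook.
    return {"grad_gnorm": float(total_norm)}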
def __init__(self, obs_space, action_space, config):
    if get_default_config:
        config = dict(get_default_config(), **config)
    self.config = config

    if before_init:
        before_init(self, obs_space, action_space, config)

    if make_model_and_action_dist:
        self.model, dist_class = make_model_and_action_dist(
            self, obs_space, action_space, config)
        # Make sure we were passed a correct Model factory.
        assert isinstance(self.model, TorchModelV2), \
            "ERROR: TorchPolicy::make_model_and_action_dist must " \
            "return a TorchModelV2 object!"
    else:
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"], framework="torch")
        self.model = ModelCatalog.get_model_v2(
            obs_space=obs_space,
            action_space=action_space,
            num_outputs=logit_dim,
            model_config=self.config["model"],
            framework="torch",
            **self.config["model"].get("custom_options", {}))

    TorchPolicy.__init__(
        self,
        obs_space,
        action_space,
        config,
        model=self.model,
        loss=loss_fn,
        action_distribution_class=dist_class,
        action_sampler_fn=action_sampler_fn,
        action_distribution_fn=action_distribution_fn,
        max_seq_len=config["model"]["max_seq_len"],
        get_batch_divisibility_req=get_batch_divisibility_req,
    )

    if after_init:
        after_init(self, obs_space, action_space, config)
def optimizer(self):
    if optimizer_fn:
        optimizers = optimizer_fn(self, self.config)
    else:
        optimizers = TorchPolicy.optimizer(self)
    optimizers = force_list(optimizers)
    if hasattr(self, "exploration"):
        exploration_optimizers = force_list(
            self.exploration.get_exploration_optimizer(self.config))
        optimizers.extend(exploration_optimizers)
    return optimizers
def set_weights(self, weights):
    # Makes sure that whenever we restore weights for this policy's
    # model, we sync the target network (from the main model) at the
    # same time.
    TorchPolicy.set_weights(self, weights)
    self.update_target()
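# Hedged usage sketch for the override above: copying weights from one policy
# into another also re-syncs the target network, so main net and target net
# cannot diverge across a checkpoint restore. Both policy arguments are
# placeholders.
def restore_policy_weights(src_policy, dst_policy):
    # set_weights() internally also calls update_target().
    dst_policy.set_weights(src_policy.get_weights())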
def extra_grad_info(self, batch_tensors):
    if stats_fn:
        return stats_fn(self, batch_tensors)
    else:
        return TorchPolicy.extra_grad_info(self, batch_tensors)
def optimizer(self):
    if optimizer_fn:
        return optimizer_fn(self, self.config)
    else:
        return TorchPolicy.optimizer(self)
def extra_grad_process(self):
    if extra_grad_process_fn:
        return extra_grad_process_fn(self)
    else:
        return TorchPolicy.extra_grad_process(self)
def extra_grad_info(self, train_batch):
    if stats_fn:
        return stats_fn(self, train_batch)
    else:
        return TorchPolicy.extra_grad_info(self, train_batch)
def apply_gradients(self, gradients):
    if apply_gradients_fn:
        apply_gradients_fn(self, gradients)
    else:
        TorchPolicy.apply_gradients(self, gradients)
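# Hedged example of an `apply_gradients_fn`: write externally computed
# (numpy) gradients into the model parameters, then step the first
# optimizer. `policy._optimizers` is TorchPolicy's internal optimizer list;
# treat this as an assumption-laden sketch, not the canonical implementation.
import torch

def my_apply_gradients_fn(policy, gradients):
    for grad, param in zip(gradients, policy.model.parameters()):
        if grad is not None:
            param.grad = torch.from_numpy(grad).to(param.device)
    policy._optimizers[0].step()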
def extra_action_out(self, model_out):
    if extra_action_out_fn:
        return extra_action_out_fn(self, model_out)
    else:
        return TorchPolicy.extra_action_out(self, model_out)
def __init__(self, obs_space, action_space, config):
    if get_default_config:
        config = dict(get_default_config(), **config)
    self.config = config

    if validate_spaces:
        validate_spaces(self, obs_space, action_space, self.config)

    if before_init:
        before_init(self, obs_space, action_space, self.config)

    # Model is customized (use default action dist class).
    if make_model:
        assert make_model_and_action_dist is None, \
            "Either `make_model` or `make_model_and_action_dist`" \
            " must be None!"
        self.model = make_model(self, obs_space, action_space, config)
        dist_class, _ = ModelCatalog.get_action_dist(
            action_space, self.config["model"], framework="torch")
    # Model and action dist class are customized.
    elif make_model_and_action_dist:
        self.model, dist_class = make_model_and_action_dist(
            self, obs_space, action_space, config)
    # Use default model and default action dist.
    else:
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"], framework="torch")
        self.model = ModelCatalog.get_model_v2(
            obs_space=obs_space,
            action_space=action_space,
            num_outputs=logit_dim,
            model_config=self.config["model"],
            framework="torch")

    # Make sure we were passed a correct Model factory.
    assert isinstance(self.model, TorchModelV2), \
        "ERROR: Generated Model must be a TorchModelV2 object!"

    TorchPolicy.__init__(
        self,
        observation_space=obs_space,
        action_space=action_space,
        config=config,
        model=self.model,
        loss=loss_fn,
        action_distribution_class=dist_class,
        action_sampler_fn=action_sampler_fn,
        action_distribution_fn=action_distribution_fn,
        max_seq_len=config["model"]["max_seq_len"],
        get_batch_divisibility_req=get_batch_divisibility_req,
    )

    # Update this Policy's ViewRequirements (if function given).
    if callable(view_requirements_fn):
        self.view_requirements.update(view_requirements_fn(self))
    # Merge the Model's view requirements into the Policy's.
    self.view_requirements.update(self.model.inference_view_requirements)

    _before_loss_init = before_loss_init or after_init
    if _before_loss_init:
        _before_loss_init(self, self.observation_space, self.action_space,
                          config)

    # Perform test runs through postprocessing- and loss functions.
    self._initialize_loss_from_dummy_batch(
        auto_remove_unneeded_view_reqs=True,
        stats_fn=stats_fn,
    )

    if _after_loss_init:
        _after_loss_init(self, obs_space, action_space, config)

    # Got to reset global_timestep again after this fake run-through.
    self.global_timestep = 0
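# Hedged example of a `view_requirements_fn` for the template above: request
# the previous reward as an extra input column. ViewRequirement is RLlib's
# real class, though its exact keyword arguments vary across versions.
from ray.rllib.policy.view_requirement import ViewRequirement

def my_view_requirements_fn(policy):
    # Map the "prev_rewards" view to the "rewards" column, shifted by -1.
    return {"prev_rewards": ViewRequirement(data_col="rewards", shift=-1)}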