Example #1
    def __init__(self, output, sess=None, input_variables=None):
        """Creates TensorFlowVariables containing extracted variables.

        The variables are extracted by performing a breadth-first search (BFS)
        over the dependency graph, with the given output operation(s) as the
        root node(s). After the graph is traversed and those variables are
        collected, input_variables are appended to the collected variables.
        For each variable in the list, a placeholder and an assignment
        operation are created.

        Args:
            output (tf.Operation, List[tf.Operation]): The tensorflow
                operation to extract all variables from.
            sess (Optional[tf.Session]): Optional tf.Session used for running
                the get and set methods in tf graph mode.
                Use None for tf eager.
            input_variables (List[tf.Variable]): Variables to include in the
                list.
        """
        self.sess = sess
        output = force_list(output)
        queue = deque(output)
        variable_names = []
        explored_inputs = set(output)

        # We do a BFS on the dependency graph of the input function to find
        # the variables.
        while len(queue) != 0:
            tf_obj = queue.popleft()
            if tf_obj is None:
                continue
            # The object put into the queue is not necessarily an operation,
            # so we want the op attribute to get the operation underlying the
            # object. Only operations contain the inputs that we can explore.
            if hasattr(tf_obj, "op"):
                tf_obj = tf_obj.op
            for input_op in tf_obj.inputs:
                if input_op not in explored_inputs:
                    queue.append(input_op)
                    explored_inputs.add(input_op)
            # Tensorflow control inputs can be circular, so we keep track of
            # explored operations.
            for control in tf_obj.control_inputs:
                if control not in explored_inputs:
                    queue.append(control)
                    explored_inputs.add(control)
            if "Variable" in tf_obj.node_def.op or "VarHandle" in tf_obj.node_def.op:
                variable_names.append(tf_obj.node_def.name)
        self.variables = OrderedDict()
        variable_list = [
            v for v in tf1.global_variables() if v.op.node_def.name in variable_names
        ]
        if input_variables is not None:
            variable_list += input_variables

        if not tf1.executing_eagerly():
            for v in variable_list:
                self.variables[v.op.node_def.name] = v

            self.placeholders = {}
            self.assignment_nodes = {}

            # Create new placeholders to put in custom weights.
            for k, var in self.variables.items():
                self.placeholders[k] = tf1.placeholder(
                    var.value().dtype,
                    var.get_shape().as_list(),
                    name="Placeholder_" + k,
                )
                self.assignment_nodes[k] = var.assign(self.placeholders[k])
        else:
            for v in variable_list:
                self.variables[v.name] = v
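The constructor above is, at its core, a breadth-first traversal with an explored set, so that circular control inputs cannot cause an infinite loop. Below is a minimal standalone sketch of that pattern on a toy dependency graph (plain Python, no TensorFlow; the node names are made up for illustration):

from collections import deque

# Toy dependency graph: op name -> list of input op names.
deps = {
    "loss": ["matmul", "bias_add"],
    "matmul": ["w", "x"],
    "bias_add": ["b", "matmul"],  # "matmul" is reachable twice; the explored set dedupes it
    "w": [], "x": [], "b": [],
}
variable_ops = {"w", "b"}  # ops we want to collect, analogous to Variable/VarHandle ops

queue = deque(["loss"])
explored = {"loss"}
found = []
while queue:
    node = queue.popleft()
    if node in variable_ops:
        found.append(node)
    for dep in deps[node]:
        if dep not in explored:
            explored.add(dep)
            queue.append(dep)

assert set(found) == {"w", "b"}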
Example #2
    def __init__(self,
                 observation_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 config: TrainerConfigDict,
                 *,
                 model: ModelV2,
                 loss: Callable[[Policy, ModelV2, type, SampleBatch],
                                TensorType],
                 action_distribution_class: TorchDistributionWrapper,
                 action_sampler_fn: Callable[[TensorType, List[TensorType]],
                                             Tuple[TensorType,
                                                   TensorType]] = None,
                 action_distribution_fn: Optional[Callable[
                     [Policy, ModelV2, TensorType, TensorType, TensorType],
                     Tuple[TensorType, type, List[TensorType]]]] = None,
                 max_seq_len: int = 20,
                 get_batch_divisibility_req: Optional[int] = None):
        """Build a policy from policy and loss torch modules.

        Note that the model will be placed on a GPU device if
        CUDA_VISIBLE_DEVICES is set. Only a single GPU is supported for now.

        Args:
            observation_space (gym.spaces.Space): observation space of the
                policy.
            action_space (gym.spaces.Space): action space of the policy.
            config (TrainerConfigDict): The Policy config dict.
            model (ModelV2): PyTorch policy module. Given observations as
                input, this module must return a list of outputs where the
                first item is action logits, and the rest can be any value.
            loss (Callable[[Policy, ModelV2, type, SampleBatch], TensorType]):
                Function that takes (policy, model, dist_class, train_batch)
                and returns a single scalar loss.
            action_distribution_class (TorchDistributionWrapper): Class for
                a torch action distribution.
            action_sampler_fn (Callable[[TensorType, List[TensorType]],
                Tuple[TensorType, TensorType]]): A callable returning a
                sampled action and its log-likelihood given Policy, ModelV2,
                input_dict, explore, timestep, and is_training.
            action_distribution_fn (Optional[Callable[[Policy, ModelV2,
                Dict[str, TensorType], TensorType, TensorType],
                Tuple[TensorType, type, List[TensorType]]]]): A callable
                returning distribution inputs (parameters), a dist-class to
                generate an action distribution object from, and
                internal-state outputs (or an empty list if not applicable).
                Note: No Exploration hooks need to be called from within
                `action_distribution_fn`. It should only perform a simple
                forward pass through some model.
                If None, pass inputs through `self.model()` to get distribution
                inputs.
                The callable takes as inputs: Policy, ModelV2, input_dict,
                explore, timestep, is_training.
            max_seq_len (int): Max sequence length for LSTM training.
            get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
                Optional callable that returns the divisibility requirement
                for sample batches given the Policy.
        """
        self.framework = "torch"
        super().__init__(observation_space, action_space, config)
        if torch.cuda.is_available() and ray.get_gpu_ids(as_str=True):
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        self.model = model.to(self.device)
        # Combine view_requirements for Model and Policy.
        self.view_requirements = {
            **self.model.inference_view_requirements(),
            **self.training_view_requirements(),
        }
        self.exploration = self._create_exploration()
        self.unwrapped_model = model  # used to support DistributedDataParallel
        self._loss = loss
        self._optimizers = force_list(self.optimizer())

        self.dist_class = action_distribution_class
        self.action_sampler_fn = action_sampler_fn
        self.action_distribution_fn = action_distribution_fn

        # If set, means we are using distributed allreduce during learning.
        self.distributed_world_size = None

        self.max_seq_len = max_seq_len
        self.batch_divisibility_req = get_batch_divisibility_req(self) if \
            callable(get_batch_divisibility_req) else \
            (get_batch_divisibility_req or 1)
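The last assignment accepts `get_batch_divisibility_req` either as a callable (invoked with the policy) or as a plain int, falling back to 1. A minimal sketch of that resolution pattern with a stand-in policy object (the helper name here is illustrative, not RLlib API):

def resolve_divisibility(policy, get_batch_divisibility_req=None):
    # Callable -> call it with the policy; int -> use it; None/0 -> default of 1.
    if callable(get_batch_divisibility_req):
        return get_batch_divisibility_req(policy)
    return get_batch_divisibility_req or 1

assert resolve_divisibility(object()) == 1
assert resolve_divisibility(object(), 8) == 8
assert resolve_divisibility(object(), lambda policy: 4) == 4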
Example #3
File: torch_policy.py  Project: rlan/ray
    def __init__(
            self,
            observation_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            config: TrainerConfigDict,
            *,
            model: ModelV2,
            loss: Callable[[
                Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch
            ], Union[TensorType, List[TensorType]]],
            action_distribution_class: Type[TorchDistributionWrapper],
            action_sampler_fn: Optional[Callable[[
                TensorType, List[TensorType]
            ], Tuple[TensorType, TensorType]]] = None,
            action_distribution_fn: Optional[Callable[[
                Policy, ModelV2, TensorType, TensorType, TensorType
            ], Tuple[TensorType, Type[TorchDistributionWrapper], List[
                TensorType]]]] = None,
            max_seq_len: int = 20,
            get_batch_divisibility_req: Optional[Callable[[Policy],
                                                          int]] = None,
    ):
        """Build a policy from policy and loss torch modules.

        Note that the model will be placed on a GPU device if
        CUDA_VISIBLE_DEVICES is set. Only a single GPU is supported for now.

        Args:
            observation_space (gym.spaces.Space): observation space of the
                policy.
            action_space (gym.spaces.Space): action space of the policy.
            config (TrainerConfigDict): The Policy config dict.
            model (ModelV2): PyTorch policy module. Given observations as
                input, this module must return a list of outputs where the
                first item is action logits, and the rest can be any value.
            loss (Callable[[Policy, ModelV2, Type[TorchDistributionWrapper],
                SampleBatch], Union[TensorType, List[TensorType]]]): Callable
                that returns a single scalar loss or a list of loss terms.
            action_distribution_class (Type[TorchDistributionWrapper]): Class
                for a torch action distribution.
            action_sampler_fn (Callable[[TensorType, List[TensorType]],
                Tuple[TensorType, TensorType]]): A callable returning a
                sampled action and its log-likelihood given Policy, ModelV2,
                input_dict, explore, timestep, and is_training.
            action_distribution_fn (Optional[Callable[[Policy, ModelV2,
                ModelInputDict, TensorType, TensorType],
                Tuple[TensorType, type, List[TensorType]]]]): A callable
                returning distribution inputs (parameters), a dist-class to
                generate an action distribution object from, and
                internal-state outputs (or an empty list if not applicable).
                Note: No Exploration hooks need to be called from within
                `action_distribution_fn`. It should only perform a simple
                forward pass through some model.
                If None, pass inputs through `self.model()` to get distribution
                inputs.
                The callable takes as inputs: Policy, ModelV2, ModelInputDict,
                explore, timestep, is_training.
            max_seq_len (int): Max sequence length for LSTM training.
            get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
                Optional callable that returns the divisibility requirement
                for sample batches given the Policy.
        """
        self.framework = "torch"
        super().__init__(observation_space, action_space, config)

        # Create multi-GPU model towers, if necessary.
        # - The central main model will be stored under self.model, residing
        #   on self.device.
        # - Each GPU will have a copy of that model under
        #   self.model_gpu_towers, matching the devices in self.devices.
        # - Parallelization is done by splitting the train batch and passing
        #   it through the model copies in parallel, then averaging over the
        #   resulting gradients, applying these averages on the main model and
        #   updating all towers' weights from the main model.
        # - In case of just one device (1 (fake) GPU or 1 CPU), no
        #   parallelization will be done.

        # Get devices to build the graph on.
        worker_idx = self.config.get("worker_index", 0)
        if not config["_fake_gpus"] and \
                ray.worker._mode() == ray.worker.LOCAL_MODE:
            num_gpus = 0
        elif worker_idx == 0:
            num_gpus = config["num_gpus"]
        else:
            num_gpus = config["num_gpus_per_worker"]
        gpu_ids = list(range(torch.cuda.device_count()))

        # Place on one or more CPU(s) when either:
        # - Fake GPU mode.
        # - num_gpus=0 (either set by user or we are in local_mode=True).
        # - no GPUs available.
        if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
            logger.info("TorchPolicy (worker={}) running on {}.".format(
                worker_idx
                if worker_idx > 0 else "local", "{} fake-GPUs".format(num_gpus)
                if config["_fake_gpus"] else "CPU"))
            self.device = torch.device("cpu")
            self.devices = [
                self.device for _ in range(int(math.ceil(num_gpus)) or 1)
            ]
            self.model_gpu_towers = [
                model if i == 0 else copy.deepcopy(model)
                for i in range(int(math.ceil(num_gpus)) or 1)
            ]
            if hasattr(self, "target_model"):
                self.target_models = {
                    m: self.target_model
                    for m in self.model_gpu_towers
                }
            self.model = model
        # Place on one or more actual GPU(s), when:
        # - num_gpus > 0 (set by user) AND
        # - local_mode=False AND
        # - actual GPUs available AND
        # - non-fake GPU mode.
        else:
            logger.info("TorchPolicy (worker={}) running on {} GPU(s).".format(
                worker_idx if worker_idx > 0 else "local", num_gpus))
            # We are a remote worker (WORKER_MODE=1):
            # GPUs should be assigned to us by ray.
            if ray.worker._mode() == ray.worker.WORKER_MODE:
                gpu_ids = ray.get_gpu_ids()

            if len(gpu_ids) < num_gpus:
                raise ValueError(
                    "TorchPolicy was not able to find enough GPU IDs! Found "
                    f"{gpu_ids}, but num_gpus={num_gpus}.")

            self.devices = [
                torch.device("cuda:{}".format(i))
                for i, id_ in enumerate(gpu_ids) if i < num_gpus
            ]
            self.device = self.devices[0]
            ids = [id_ for i, id_ in enumerate(gpu_ids) if i < num_gpus]
            self.model_gpu_towers = []
            for i, _ in enumerate(ids):
                model_copy = copy.deepcopy(model)
                self.model_gpu_towers.append(model_copy.to(self.devices[i]))
            if hasattr(self, "target_model"):
                self.target_models = {
                    m: copy.deepcopy(self.target_model).to(self.devices[i])
                    for i, m in enumerate(self.model_gpu_towers)
                }
            self.model = self.model_gpu_towers[0]

        # Lock used for locking some methods on the object-level.
        # This prevents possible race conditions when calling the model
        # first, then its value function (e.g. in a loss function), in
        # between of which another model call is made (e.g. to compute an
        # action).
        self._lock = threading.RLock()

        self._state_inputs = self.model.get_initial_state()
        self._is_recurrent = len(self._state_inputs) > 0
        # Auto-update model's inference view requirements, if recurrent.
        self._update_model_view_requirements_from_init_state()
        # Combine view_requirements for Model and Policy.
        self.view_requirements.update(self.model.view_requirements)

        self.exploration = self._create_exploration()
        self.unwrapped_model = model  # used to support DistributedDataParallel
        self._loss = loss
        self._optimizers = force_list(self.optimizer())
        # Store, which params (by index within the model's list of
        # parameters) should be updated per optimizer.
        # Maps optimizer idx to a set of param indices.
        self.multi_gpu_param_groups: List[Set[int]] = []
        main_params = {p: i for i, p in enumerate(self.model.parameters())}
        for o in self._optimizers:
            param_indices = []
            for pg_idx, pg in enumerate(o.param_groups):
                for p in pg["params"]:
                    param_indices.append(main_params[p])
            self.multi_gpu_param_groups.append(set(param_indices))

        # Create n sample-batch buffers (num_multi_gpu_tower_stacks), each
        # one with m towers (num_gpus).
        num_buffers = self.config.get("num_multi_gpu_tower_stacks", 1)
        self._loaded_batches = [[] for _ in range(num_buffers)]

        self.dist_class = action_distribution_class
        self.action_sampler_fn = action_sampler_fn
        self.action_distribution_fn = action_distribution_fn

        # If set, means we are using distributed allreduce during learning.
        self.distributed_world_size = None

        self.max_seq_len = max_seq_len
        self.batch_divisibility_req = get_batch_divisibility_req(self) if \
            callable(get_batch_divisibility_req) else \
            (get_batch_divisibility_req or 1)
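The multi-GPU comments above describe splitting the train batch across model copies, averaging the resulting gradients onto the main model, and syncing the towers back. A minimal CPU-only sketch of that update scheme with a toy model (assumed toy data and model, not RLlib's actual tower code):

import copy
import torch
import torch.nn as nn

main_model = nn.Linear(4, 2)
towers = [main_model, copy.deepcopy(main_model)]  # tower 0 is the main model itself
optimizer = torch.optim.SGD(main_model.parameters(), lr=0.1)

batch_x, batch_y = torch.randn(8, 4), torch.randn(8, 2)
shards_x = torch.chunk(batch_x, len(towers))
shards_y = torch.chunk(batch_y, len(towers))

# Per-tower gradients on the batch shards.
tower_grads = []
for tower, x, y in zip(towers, shards_x, shards_y):
    tower.zero_grad()
    nn.functional.mse_loss(tower(x), y).backward()
    tower_grads.append([p.grad.clone() for p in tower.parameters()])

# Average the tower gradients onto the main model and apply them.
for main_p, *grads in zip(main_model.parameters(), *tower_grads):
    main_p.grad = torch.stack(grads).mean(dim=0)
optimizer.step()

# Refresh tower weights from the updated main model.
with torch.no_grad():
    for tower in towers[1:]:
        tower.load_state_dict(main_model.state_dict())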
Example #4
        def __init__(self, observation_space, action_space, config):
            # If this class runs as a @ray.remote actor, eager mode may not
            # have been activated yet.
            if not tf1.executing_eagerly():
                tf1.enable_eager_execution()
            self.framework = config.get("framework", "tfe")
            Policy.__init__(self, observation_space, action_space, config)

            # Log device and worker index.
            from ray.rllib.evaluation.rollout_worker import get_global_worker

            worker = get_global_worker()
            worker_idx = worker.worker_index if worker else 0
            if get_gpu_devices():
                logger.info(
                    "TF-eager Policy (worker={}) running on GPU.".format(
                        worker_idx if worker_idx > 0 else "local"))
            else:
                logger.info(
                    "TF-eager Policy (worker={}) running on CPU.".format(
                        worker_idx if worker_idx > 0 else "local"))

            self._is_training = False

            # Only for `config.eager_tracing=True`: A counter to keep track of
            # how many times an eager-traced method (e.g.
            # `self._compute_actions_helper`) has been re-traced by tensorflow.
            # We will raise an error if more than n re-tracings have been
            # detected, since this would considerably slow down execution.
            # The variable below should only get incremented during the
            # tf.function trace operations, never when calling the already
            # traced function after that.
            self._re_trace_counter = 0

            self._loss_initialized = False
            # To ensure backward compatibility:
            # Old way: If `loss` provided here, use as-is (as a function).
            if loss_fn is not None:
                self._loss = loss_fn
            # New way: Convert the overridden `self.loss` into a plain
            # function, so it can be called the same way as `loss` would
            # be, ensuring backward compatibility.
            elif self.loss.__func__.__qualname__ != "Policy.loss":
                self._loss = self.loss.__func__
            # `loss` not provided nor overridden from Policy -> Set to None.
            else:
                self._loss = None

            self.batch_divisibility_req = (get_batch_divisibility_req(self) if
                                           callable(get_batch_divisibility_req)
                                           else
                                           (get_batch_divisibility_req or 1))
            self._max_seq_len = config["model"]["max_seq_len"]

            if get_default_config:
                config = dict(get_default_config(), **config)

            if validate_spaces:
                validate_spaces(self, observation_space, action_space, config)

            if before_init:
                before_init(self, observation_space, action_space, config)

            self.config = config
            self.dist_class = None
            if action_sampler_fn or action_distribution_fn:
                if not make_model:
                    raise ValueError(
                        "`make_model` is required if `action_sampler_fn` OR "
                        "`action_distribution_fn` is given")
            else:
                self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"])

            if make_model:
                self.model = make_model(self, observation_space, action_space,
                                        config)
            else:
                self.model = ModelCatalog.get_model_v2(
                    observation_space,
                    action_space,
                    logit_dim,
                    config["model"],
                    framework=self.framework,
                )
            # Lock used for locking some methods on the object-level.
            # This prevents possible race conditions when calling the model
            # first, then its value function (e.g. in a loss function), in
            # between of which another model call is made (e.g. to compute an
            # action).
            self._lock = threading.RLock()

            # Auto-update model's inference view requirements, if recurrent.
            self._update_model_view_requirements_from_init_state()

            self.exploration = self._create_exploration()
            self._state_inputs = self.model.get_initial_state()
            self._is_recurrent = len(self._state_inputs) > 0

            # Combine view_requirements for Model and Policy.
            self.view_requirements.update(self.model.view_requirements)

            if before_loss_init:
                before_loss_init(self, observation_space, action_space, config)

            if optimizer_fn:
                optimizers = optimizer_fn(self, config)
            else:
                optimizers = tf.keras.optimizers.Adam(config["lr"])
            optimizers = force_list(optimizers)
            if getattr(self, "exploration", None):
                optimizers = self.exploration.get_exploration_optimizer(
                    optimizers)

            # The list of local (tf) optimizers (one per loss term).
            self._optimizers: List[LocalOptimizer] = optimizers
            # Backward compatibility: A user's policy may only support a single
            # loss term and optimizer (no lists).
            self._optimizer: LocalOptimizer = optimizers[
                0] if optimizers else None

            self._initialize_loss_from_dummy_batch(
                auto_remove_unneeded_view_reqs=True,
                stats_fn=stats_fn,
            )
            self._loss_initialized = True

            if after_init:
                after_init(self, observation_space, action_space, config)

            # Need to reset global_timestep again after the fake run-throughs.
            self.global_timestep = 0
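One detail worth noting above: `config = dict(get_default_config(), **config)` layers the user-provided config on top of the algorithm's defaults. A toy illustration of that merge (the keys and values below are made up):

def get_default_config():
    return {"lr": 0.001, "gamma": 0.99, "framework": "tfe"}

user_config = {"lr": 0.0005}
config = dict(get_default_config(), **user_config)  # user keys override defaults
assert config == {"lr": 0.0005, "gamma": 0.99, "framework": "tfe"}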
Example #5
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict,
        *,
        model: ModelV2,
        loss: Callable[
            [Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch],
            Union[TensorType, List[TensorType]]],
        action_distribution_class: Type[TorchDistributionWrapper],
        action_sampler_fn: Optional[Callable[[TensorType, List[TensorType]],
                                             Tuple[TensorType,
                                                   TensorType]]] = None,
        action_distribution_fn: Optional[
            Callable[[Policy, ModelV2, TensorType, TensorType, TensorType],
                     Tuple[TensorType, Type[TorchDistributionWrapper],
                           List[TensorType]]]] = None,
        max_seq_len: int = 20,
        get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
    ):
        """Build a policy from policy and loss torch modules.

        Note that the model will be placed on a GPU device if
        CUDA_VISIBLE_DEVICES is set. Only a single GPU is supported for now.

        Args:
            observation_space (gym.spaces.Space): observation space of the
                policy.
            action_space (gym.spaces.Space): action space of the policy.
            config (TrainerConfigDict): The Policy config dict.
            model (ModelV2): PyTorch policy module. Given observations as
                input, this module must return a list of outputs where the
                first item is action logits, and the rest can be any value.
            loss (Callable[[Policy, ModelV2, Type[TorchDistributionWrapper],
                SampleBatch], Union[TensorType, List[TensorType]]]): Callable
                that returns a single scalar loss or a list of loss terms.
            action_distribution_class (Type[TorchDistributionWrapper]): Class
                for a torch action distribution.
            action_sampler_fn (Callable[[TensorType, List[TensorType]],
                Tuple[TensorType, TensorType]]): A callable returning a
                sampled action and its log-likelihood given Policy, ModelV2,
                input_dict, explore, timestep, and is_training.
            action_distribution_fn (Optional[Callable[[Policy, ModelV2,
                Dict[str, TensorType], TensorType, TensorType],
                Tuple[TensorType, type, List[TensorType]]]]): A callable
                returning distribution inputs (parameters), a dist-class to
                generate an action distribution object from, and
                internal-state outputs (or an empty list if not applicable).
                Note: No Exploration hooks need to be called from within
                `action_distribution_fn`. It should only perform a simple
                forward pass through some model.
                If None, pass inputs through `self.model()` to get distribution
                inputs.
                The callable takes as inputs: Policy, ModelV2, input_dict,
                explore, timestep, is_training.
            max_seq_len (int): Max sequence length for LSTM training.
            get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
                Optional callable that returns the divisibility requirement
                for sample batches given the Policy.
        """
        self.framework = "torch"
        Policy.__init__(self, observation_space, action_space, config)

        counter = ray.get_actor("global_counter")
        ray.get(counter.inc.remote(1))
        count = ray.get(counter.get.remote())
        print(f"{count}********************")
        self.device = xm.xla_device(n=count)  # DIFFERENCE HERE FOR TPU USAGE

        self.model = model.to(self.device)
        # Combine view_requirements for Model and Policy.
        self.training_view_requirements = dict(
            **{
                SampleBatch.ACTIONS:
                ViewRequirement(space=self.action_space, shift=0),
                SampleBatch.REWARDS:
                ViewRequirement(shift=0),
                SampleBatch.DONES:
                ViewRequirement(shift=0),
            }, **self.model.inference_view_requirements)

        self.exploration = self._create_exploration()
        self.unwrapped_model = model  # used to support DistributedDataParallel
        self._loss = loss
        self._optimizers = force_list(self.optimizer())

        self.dist_class = action_distribution_class
        self.action_sampler_fn = action_sampler_fn
        self.action_distribution_fn = action_distribution_fn

        # If set, means we are using distributed allreduce during learning.
        self.distributed_world_size = None

        self.max_seq_len = max_seq_len
        self.batch_divisibility_req = get_batch_divisibility_req(self) if \
            callable(get_batch_divisibility_req) else \
            (get_batch_divisibility_req or 1)
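The TPU variant above fetches a named actor called "global_counter" to pick an XLA device index. The actor itself is created elsewhere; below is a minimal sketch of a compatible named counter actor (the class name and method bodies are assumptions inferred from the calls made above):

import ray

@ray.remote
class GlobalCounter:
    def __init__(self):
        self.value = 0

    def inc(self, n):
        self.value += n

    def get(self):
        return self.value

ray.init(ignore_reinit_error=True)
counter = GlobalCounter.options(name="global_counter").remote()
ray.get(counter.inc.remote(1))
assert ray.get(ray.get_actor("global_counter").get.remote()) == 1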
Example #6
        def __init__(self, observation_space, action_space, config):
            assert tf.executing_eagerly()
            self.framework = config.get("framework", "tfe")
            Policy.__init__(self, observation_space, action_space, config)

            # Log device and worker index.
            from ray.rllib.evaluation.rollout_worker import get_global_worker
            worker = get_global_worker()
            worker_idx = worker.worker_index if worker else 0
            if get_gpu_devices():
                logger.info(
                    "TF-eager Policy (worker={}) running on GPU.".format(
                        worker_idx if worker_idx > 0 else "local"))
            else:
                logger.info(
                    "TF-eager Policy (worker={}) running on CPU.".format(
                        worker_idx if worker_idx > 0 else "local"))

            self._is_training = False
            self._loss_initialized = False

            self._loss = loss_fn
            self.batch_divisibility_req = get_batch_divisibility_req(self) if \
                callable(get_batch_divisibility_req) else \
                (get_batch_divisibility_req or 1)
            self._max_seq_len = config["model"]["max_seq_len"]

            if get_default_config:
                config = dict(get_default_config(), **config)

            if validate_spaces:
                validate_spaces(self, observation_space, action_space, config)

            if before_init:
                before_init(self, observation_space, action_space, config)

            self.config = config
            self.dist_class = None
            if action_sampler_fn or action_distribution_fn:
                if not make_model:
                    raise ValueError(
                        "`make_model` is required if `action_sampler_fn` OR "
                        "`action_distribution_fn` is given")
            else:
                self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"])

            if make_model:
                self.model = make_model(self, observation_space, action_space,
                                        config)
            else:
                self.model = ModelCatalog.get_model_v2(
                    observation_space,
                    action_space,
                    logit_dim,
                    config["model"],
                    framework=self.framework,
                )
            # Lock used for locking some methods on the object-level.
            # This prevents possible race conditions when calling the model
            # first, then its value function (e.g. in a loss function), in
            # between of which another model call is made (e.g. to compute an
            # action).
            self._lock = threading.RLock()

            # Auto-update model's inference view requirements, if recurrent.
            self._update_model_view_requirements_from_init_state()

            self.exploration = self._create_exploration()
            self._state_inputs = self.model.get_initial_state()
            self._is_recurrent = len(self._state_inputs) > 0

            # Combine view_requirements for Model and Policy.
            self.view_requirements.update(self.model.view_requirements)

            if before_loss_init:
                before_loss_init(self, observation_space, action_space, config)

            if optimizer_fn:
                optimizers = optimizer_fn(self, config)
            else:
                optimizers = tf.keras.optimizers.Adam(config["lr"])
            optimizers = force_list(optimizers)
            if getattr(self, "exploration", None):
                optimizers = self.exploration.get_exploration_optimizer(
                    optimizers)

            # The list of local (tf) optimizers (one per loss term).
            self._optimizers: List[LocalOptimizer] = optimizers
            # Backward compatibility: A user's policy may only support a single
            # loss term and optimizer (no lists).
            self._optimizer: LocalOptimizer = \
                optimizers[0] if optimizers else None

            self._initialize_loss_from_dummy_batch(
                auto_remove_unneeded_view_reqs=True,
                stats_fn=stats_fn,
            )
            self._loss_initialized = True

            if after_init:
                after_init(self, observation_space, action_space, config)

            # Need to reset global_timestep again after the fake run-throughs.
            self.global_timestep = 0
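`force_list` (imported from ray.rllib.utils in the original module) is used above so a single optimizer or a list of optimizers can be handled uniformly. A re-implementation sketch of the behavior implied by its usage (not the exact library source):

def force_list(elements=None):
    # None -> empty list; single object -> one-element list; list/tuple -> list.
    if elements is None:
        return []
    return list(elements) if isinstance(elements, (list, tuple)) else [elements]

assert force_list(None) == []
assert force_list("adam") == ["adam"]
assert force_list(("adam", "sgd")) == ["adam", "sgd"]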
Example #7
        def _worker(shard_idx, model, sample_batch, device):
            torch.set_grad_enabled(grad_enabled)
            try:
                with NullContextManager(
                ) if device.type == "cpu" else torch.cuda.device(  # noqa: E501
                        device):
                    loss_out = force_list(
                        self._loss(self, model, self.dist_class, sample_batch))

                    # Call Model's custom-loss with Policy loss outputs and
                    # train_batch.
                    loss_out = model.custom_loss(loss_out, sample_batch)

                    assert len(loss_out) == len(self._optimizers)

                    # Loop through all optimizers.
                    grad_info = {"allreduce_latency": 0.0}

                    parameters = list(model.parameters())
                    all_grads = [None for _ in range(len(parameters))]
                    for opt_idx, opt in enumerate(self._optimizers):
                        # Erase gradients in all vars of the tower that this
                        # optimizer would affect.
                        param_indices = self.multi_gpu_param_groups[opt_idx]
                        for param_idx, param in enumerate(parameters):
                            if param_idx in param_indices and param.grad is not None:
                                param.grad.data.zero_()
                        # Recompute gradients of loss over all variables.
                        loss_out[opt_idx].backward(retain_graph=True)
                        grad_info.update(
                            self.extra_grad_process(opt, loss_out[opt_idx]))

                        grads = []
                        # Note that return values are just references;
                        # Calling zero_grad would modify the values.
                        for param_idx, param in enumerate(parameters):
                            if param_idx in param_indices:
                                if param.grad is not None:
                                    grads.append(param.grad)
                                all_grads[param_idx] = param.grad

                        if self.distributed_world_size:
                            start = time.time()
                            if torch.cuda.is_available():
                                # Sadly, allreduce_coalesced does not work with
                                # CUDA yet.
                                for g in grads:
                                    torch.distributed.all_reduce(
                                        g, op=torch.distributed.ReduceOp.SUM)
                            else:
                                torch.distributed.all_reduce_coalesced(
                                    grads, op=torch.distributed.ReduceOp.SUM)

                            for param_group in opt.param_groups:
                                for p in param_group["params"]:
                                    if p.grad is not None:
                                        p.grad /= self.distributed_world_size

                            grad_info["allreduce_latency"] += time.time(
                            ) - start

                with lock:
                    results[shard_idx] = (all_grads, grad_info)
            except Exception as e:
                import traceback

                with lock:
                    results[shard_idx] = (
                        ValueError(e.args[0] + "\n traceback" +
                                   traceback.format_exc() + "\n" +
                                   "In tower {} on device {}".format(
                                       shard_idx, device)),
                        e,
                    )
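The `NullContextManager` above exists only so the `with` block is a no-op on CPU while selecting a CUDA device on GPU. Since Python 3.7, contextlib.nullcontext can play the same role; a small sketch of the conditional-device pattern (toy model, no RLlib machinery):

import contextlib
import torch

device = torch.device("cpu")
ctx = contextlib.nullcontext() if device.type == "cpu" else torch.cuda.device(device)
with ctx:
    model = torch.nn.Linear(2, 2).to(device)
    loss = model(torch.randn(4, 2)).sum()
    loss.backward()
    grads = [p.grad for p in model.parameters()]  # per-parameter gradients, as collected above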
Example #8
File: tf_policy.py  Project: parasj/ray
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: AlgorithmConfigDict,
        sess: "tf1.Session",
        obs_input: TensorType,
        sampled_action: TensorType,
        loss: Union[TensorType, List[TensorType]],
        loss_inputs: List[Tuple[str, TensorType]],
        model: Optional[ModelV2] = None,
        sampled_action_logp: Optional[TensorType] = None,
        action_input: Optional[TensorType] = None,
        log_likelihood: Optional[TensorType] = None,
        dist_inputs: Optional[TensorType] = None,
        dist_class: Optional[type] = None,
        state_inputs: Optional[List[TensorType]] = None,
        state_outputs: Optional[List[TensorType]] = None,
        prev_action_input: Optional[TensorType] = None,
        prev_reward_input: Optional[TensorType] = None,
        seq_lens: Optional[TensorType] = None,
        max_seq_len: int = 20,
        batch_divisibility_req: int = 1,
        update_ops: List[TensorType] = None,
        explore: Optional[TensorType] = None,
        timestep: Optional[TensorType] = None,
    ):
        """Initializes a Policy object.

        Args:
            observation_space: Observation space of the policy.
            action_space: Action space of the policy.
            config: Policy-specific configuration data.
            sess: The TensorFlow session to use.
            obs_input: Input placeholder for observations, of shape
                [BATCH_SIZE, obs...].
            sampled_action: Tensor for sampling an action, of shape
                [BATCH_SIZE, action...]
            loss: Scalar policy loss output tensor or a list thereof
                (in case there is more than one loss).
            loss_inputs: A (name, placeholder) tuple for each loss input
                argument. Each placeholder name must
                correspond to a SampleBatch column key returned by
                postprocess_trajectory(), and has shape [BATCH_SIZE, data...].
                These keys will be read from postprocessed sample batches and
                fed into the specified placeholders during loss computation.
            model: The optional ModelV2 to use for calculating actions and
                losses. If not None, TFPolicy will provide functionality for
                getting variables, calling the model's custom loss (if
                provided), and importing weights into the model.
            sampled_action_logp: log probability of the sampled action.
            action_input: Input placeholder for actions for
                logp/log-likelihood calculations.
            log_likelihood: Tensor to calculate the log_likelihood (given
                action_input and obs_input).
            dist_class: An optional ActionDistribution class to use for
                generating a dist object from distribution inputs.
            dist_inputs: Tensor to calculate the distribution
                inputs/parameters.
            state_inputs: List of RNN state input Tensors.
            state_outputs: List of RNN state output Tensors.
            prev_action_input: placeholder for previous actions.
            prev_reward_input: placeholder for previous rewards.
            seq_lens: Placeholder for RNN sequence lengths, of shape
                [NUM_SEQUENCES].
                Note that NUM_SEQUENCES << BATCH_SIZE. See
                policy/rnn_sequencing.py for more information.
            max_seq_len: Max sequence length for LSTM training.
            batch_divisibility_req: Pad all agent experience batches to
                multiples of this value. This only has an effect if not using
                an LSTM model.
            update_ops: override the batchnorm update ops
                to run when applying gradients. Otherwise we run all update
                ops found in the current variable scope.
            explore: Placeholder for `explore` parameter into call to
                Exploration.get_exploration_action. Explicitly set this to
                False for not creating any Exploration component.
            timestep: Placeholder for the global sampling timestep.
        """
        self.framework = "tf"
        super().__init__(observation_space, action_space, config)

        # Get devices to build the graph on.
        num_gpus = self._get_num_gpus_for_policy()
        gpu_ids = get_gpu_devices()
        logger.info(f"Found {len(gpu_ids)} visible cuda devices.")

        # Place on one or more CPU(s) when either:
        # - Fake GPU mode.
        # - num_gpus=0 (either set by user or we are in local_mode=True).
        # - no GPUs available.
        if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
            self.devices = [
                "/cpu:0" for _ in range(int(math.ceil(num_gpus)) or 1)
            ]
        # Place on one or more actual GPU(s), when:
        # - num_gpus > 0 (set by user) AND
        # - local_mode=False AND
        # - actual GPUs available AND
        # - non-fake GPU mode.
        else:
            # We are a remote worker (WORKER_MODE=1):
            # GPUs should be assigned to us by ray.
            if ray._private.worker._mode() == ray._private.worker.WORKER_MODE:
                gpu_ids = ray.get_gpu_ids()

            if len(gpu_ids) < num_gpus:
                raise ValueError(
                    "TFPolicy was not able to find enough GPU IDs! Found "
                    f"{gpu_ids}, but num_gpus={num_gpus}.")

            self.devices = [
                f"/gpu:{i}" for i, _ in enumerate(gpu_ids) if i < num_gpus
            ]

        # Disable env-info placeholder.
        if SampleBatch.INFOS in self.view_requirements:
            self.view_requirements[
                SampleBatch.INFOS].used_for_compute_actions = False
            self.view_requirements[SampleBatch.INFOS].used_for_training = False
            # Optionally add `infos` to the output dataset
            if self.config["output_config"].get("store_infos", False):
                self.view_requirements[
                    SampleBatch.INFOS].used_for_training = True

        assert model is None or isinstance(model, (ModelV2, tf.keras.Model)), (
            "Model classes for TFPolicy other than `ModelV2|tf.keras.Model` "
            "not allowed! You passed in {}.".format(model))
        self.model = model
        # Auto-update model's inference view requirements, if recurrent.
        if self.model is not None:
            self._update_model_view_requirements_from_init_state()

        # If `explore` is explicitly set to False, don't create an exploration
        # component.
        self.exploration = self._create_exploration(
        ) if explore is not False else None

        self._sess = sess
        self._obs_input = obs_input
        self._prev_action_input = prev_action_input
        self._prev_reward_input = prev_reward_input
        self._sampled_action = sampled_action
        self._is_training = self._get_is_training_placeholder()
        self._is_exploring = (explore if explore is not None else
                              tf1.placeholder_with_default(
                                  True, (), name="is_exploring"))
        self._sampled_action_logp = sampled_action_logp
        self._sampled_action_prob = (tf.math.exp(self._sampled_action_logp)
                                     if self._sampled_action_logp is not None
                                     else None)
        self._action_input = action_input  # For logp calculations.
        self._dist_inputs = dist_inputs
        self.dist_class = dist_class
        self._cached_extra_action_out = None
        self._state_inputs = state_inputs or []
        self._state_outputs = state_outputs or []
        self._seq_lens = seq_lens
        self._max_seq_len = max_seq_len

        if self._state_inputs and self._seq_lens is None:
            raise ValueError(
                "seq_lens tensor must be given if state inputs are defined")

        self._batch_divisibility_req = batch_divisibility_req
        self._update_ops = update_ops
        self._apply_op = None
        self._stats_fetches = {}
        self._timestep = (timestep if timestep is not None else
                          tf1.placeholder_with_default(tf.zeros(
                              (), dtype=tf.int64), (),
                                                       name="timestep"))

        self._optimizers: List[LocalOptimizer] = []
        # Backward compatibility and for some code shared with tf-eager Policy.
        self._optimizer = None

        self._grads_and_vars: Union[ModelGradients, List[ModelGradients]] = []
        self._grads: Union[ModelGradients, List[ModelGradients]] = []
        # Policy tf-variables (weights), whose values to get/set via
        # get_weights/set_weights.
        self._variables = None
        # Local optimizer(s)' tf-variables (e.g. state vars for Adam).
        # Will be stored alongside `self._variables` when checkpointing.
        self._optimizer_variables: Optional[
            ray.experimental.tf_utils.TensorFlowVariables] = None

        # The loss tf-op(s). Number of losses must match number of optimizers.
        self._losses = []
        # Backward compatibility (in case custom child TFPolicies access this
        # property).
        self._loss = None
        # A batch dict passed into loss function as input.
        self._loss_input_dict = {}
        losses = force_list(loss)
        if len(losses) > 0:
            self._initialize_loss(losses, loss_inputs)

        # The log-likelihood calculator op.
        self._log_likelihood = log_likelihood
        if (self._log_likelihood is None and self._dist_inputs is not None
                and self.dist_class is not None):
            self._log_likelihood = self.dist_class(
                self._dist_inputs, self.model).logp(self._action_input)
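The `explore` and `timestep` tensors above default to graph-mode placeholders with default values, so callers may feed them but do not have to. A minimal TF1 sketch of that pattern, standalone and outside the policy class (assumes tensorflow with the compat.v1 API available):

import tensorflow.compat.v1 as tf1

tf1.disable_eager_execution()
is_exploring = tf1.placeholder_with_default(True, (), name="is_exploring")
timestep = tf1.placeholder_with_default(
    tf1.zeros((), dtype=tf1.int64), (), name="timestep")

with tf1.Session() as sess:
    print(sess.run([is_exploring, timestep]))                       # defaults: [True, 0]
    print(sess.run(is_exploring, feed_dict={is_exploring: False}))  # overridden: False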
Example #9
def from_config(cls, config=None, **kwargs):
    """
    Uses the given config to create an object.
    If `config` is a dict, an optional "type" key can be used as a
    "constructor hint" to specify a certain class of the object.
    If `config` is not a dict, `config`'s value is used directly as this
    "constructor hint".

    The rest of `config` (if it's a dict) will be used as kwargs for the
    constructor. Additional keys in **kwargs always take precedence
    (they overwrite keys in `config` if it is a dict).
    Also, if the config-dict or **kwargs contains the special key "_args",
    it will be popped from the dict and used as *args list to be passed
    separately to the constructor.

    The following constructor hints are valid:
    - None: Use `cls` as constructor.
    - An already instantiated object: Will be returned as is; no
        constructor call.
    - A string or an object that is a key in `cls`'s `__type_registry__`
        dict: The value in `__type_registry__` for that key will be used
        as the constructor.
    - A python callable: Use that very callable as constructor.
    - A string: Either a json/yaml filename or the name of a python
        module+class (e.g. "ray.rllib. [...] .[some class name]")

    Args:
        cls (class): The class to build an instance for (from `config`).
        config (Optional[dict,str]): The config dict or type-string or
            filename.

    Keyword Args:
        kwargs (any): Optionally, the constructor arguments may be passed
            directly here, using `config` as the type-only info. The call
            then looks like: from_config([type]?, [**kwargs for constructor])
            If `config` is already a dict, then `kwargs` will be merged
            with `config` (overwriting keys in `config`) after "type" has
            been popped out of `config`.
            If a constructor of a Configurable needs *args, the special
            key `_args` can be passed inside `kwargs` with a list value
            (e.g. kwargs={"_args": [arg1, arg2, arg3]}).

    Returns:
        any: The object generated from the config.
    """
    # `cls` is the config (config is None).
    if config is None and isinstance(cls, (dict, str)):
        config = cls
        cls = None
    # `config` is already a created object of this class ->
    # Take it as is.
    elif isinstance(cls, type) and isinstance(config, cls):
        return config

    # `type_`: Indicator for the Configurable's constructor.
    # `ctor_args`: *args arguments for the constructor.
    # `ctor_kwargs`: **kwargs arguments for the constructor.
    # Try to copy, so caller can reuse safely.
    try:
        config = deepcopy(config)
    except Exception:
        pass
    if isinstance(config, dict):
        type_ = config.pop("type", None)
        ctor_kwargs = config
        # Give kwargs priority over things defined in config dict.
        # This way, one can pass a generic `spec` and then override single
        # constructor parameters via the kwargs in the call to `from_config`.
        ctor_kwargs.update(kwargs)
    else:
        type_ = config
        if type_ is None and "type" in kwargs:
            type_ = kwargs.pop("type")
        ctor_kwargs = kwargs
    # Special `_args` field in kwargs for *args-utilizing constructors.
    ctor_args = force_list(ctor_kwargs.pop("_args", []))

    # Figure out the actual constructor (class) from `type_`.
    # None: Try __default__object (if no args/kwargs), only then
    # constructor of cls (using args/kwargs).
    if type_ is None:
        # We have a default constructor that was defined directly by cls
        # (not by its children).
        if cls is not None and hasattr(cls, "__default_constructor__") and \
                cls.__default_constructor__ is not None and \
                ctor_args == [] and \
                (
                        not hasattr(cls.__bases__[0],
                                    "__default_constructor__")
                        or
                        cls.__bases__[0].__default_constructor__ is None or
                        cls.__bases__[0].__default_constructor__ is not
                        cls.__default_constructor__
                ):
            constructor = cls.__default_constructor__
            # Default constructor's keywords into ctor_kwargs.
            if isinstance(constructor, partial):
                kwargs = merge_dicts(ctor_kwargs, constructor.keywords)
                constructor = partial(constructor.func, **kwargs)
                ctor_kwargs = {}  # erase to avoid duplicate kwarg error
        # No default constructor -> Try cls itself as constructor.
        else:
            constructor = cls
    # Try the __type_registry__ of this class.
    else:
        constructor = lookup_type(cls, type_)

        # Found in cls.__type_registry__.
        if constructor is not None:
            pass
        # type_ is False or None (and this value is not registered) ->
        # return value of type_.
        elif type_ is False or type_ is None:
            return type_
        # Python callable.
        elif callable(type_):
            constructor = type_
        # A string: Filename or a python module+class or a json/yaml str.
        elif isinstance(type_, str):
            if re.search("\.(yaml|yml|json)$", type_):
                return from_file(cls, type_, *ctor_args, **ctor_kwargs)
            # Try un-json/un-yaml'ing the string into a dict.
            obj = yaml.safe_load(type_)
            if isinstance(obj, dict):
                return from_config(cls, obj)
            try:
                obj = from_config(cls, json.loads(type_))
            except json.JSONDecodeError:
                pass
            else:
                return obj

            # Test for absolute module.class specifier.
            if type_.find(".") != -1:
                module_name, function_name = type_.rsplit(".", 1)
                try:
                    module = importlib.import_module(module_name)
                    constructor = getattr(module, function_name)
                except (ModuleNotFoundError, ImportError):
                    pass
            # If constructor still not found, try attaching cls' module,
            # then look for type_ in there.
            if constructor is None:
                try:
                    module = importlib.import_module(cls.__module__)
                    constructor = getattr(module, type_)
                except (ModuleNotFoundError, ImportError, AttributeError):
                    # Try the package as well.
                    try:
                        package_name = importlib.import_module(
                            cls.__module__).__package__
                        module = __import__(package_name, fromlist=[type_])
                        constructor = getattr(module, type_)
                    except (ModuleNotFoundError, ImportError, AttributeError):
                        pass
            if constructor is None:
                raise ValueError(
                    "String specifier ({}) in `from_config` must be a "
                    "filename, a module+class, a class within '{}', or a key "
                    "into {}.__type_registry__!".format(
                        type_, cls.__module__, cls.__name__))

    if not constructor:
        raise TypeError(
            "Invalid type '{}'. Cannot create `from_config`.".format(type_))

    # Create object with inferred constructor.
    try:
        object_ = constructor(*ctor_args, **ctor_kwargs)
    # Catch attempts to construct from an abstract class and return None.
    except TypeError as e:
        if re.match("Can't instantiate abstract class", e.args[0]):
            return None
        raise e  # Re-raise
    # No sanity check for fake (lambda)-"constructors".
    if type(constructor).__name__ != "function":
        assert isinstance(
            object_, constructor.func
            if isinstance(constructor, partial) else constructor)

    return object_
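To make the dispatch rules in the docstring concrete, here is a deliberately stripped-down sketch of the dict-with-"type"-hint convention (toy class and a hypothetical `mini_from_config` helper, not the full `from_config` logic above):

class Greeter:
    def __init__(self, name, punctuation="!"):
        self.message = "Hello, {}{}".format(name, punctuation)

def mini_from_config(cls, config, **kwargs):
    config = dict(config)                  # copy so the caller can reuse its dict
    constructor = config.pop("type", cls)  # optional "type" key as constructor hint
    config.update(kwargs)                  # explicit kwargs take precedence
    args = config.pop("_args", [])         # special "_args" key becomes *args
    return constructor(*args, **config)

g = mini_from_config(Greeter, {"_args": ["Ray"], "punctuation": "?"})
assert g.message == "Hello, Ray?"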
Example #10
def multi_from_logits(behaviour_policy_logits,
                      target_policy_logits,
                      actions,
                      discounts,
                      rewards,
                      values,
                      bootstrap_value,
                      dist_class,
                      model,
                      behaviour_action_log_probs=None,
                      clip_rho_threshold=1.0,
                      clip_pg_rho_threshold=1.0):
    """V-trace for softmax policies.

    Calculates V-trace actor-critic targets for softmax policies as described in

    "IMPALA: Scalable Distributed Deep-RL with
    Importance Weighted Actor-Learner Architectures"
    by Espeholt, Soyer, Munos et al.

    Target policy refers to the policy we are interested in improving and
    behaviour policy refers to the policy that generated the given
    rewards and actions.

    In the notation used throughout documentation and comments, T refers to the
    time dimension ranging from 0 to T-1, B refers to the batch size, and
    ACTION_SPACE refers to the list of action-space sizes (one entry per
    action component).

    Args:
        behaviour_policy_logits: A list with length of ACTION_SPACE of float32
            tensors of shapes [T, B, ACTION_SPACE[0]], ...,
            [T, B, ACTION_SPACE[-1]] with un-normalized log-probabilities
            parameterizing the softmax behavior policy.
        target_policy_logits: A list with length of ACTION_SPACE of float32
            tensors of shapes [T, B, ACTION_SPACE[0]], ...,
            [T, B, ACTION_SPACE[-1]] with un-normalized log-probabilities
            parameterizing the softmax target policy.
        actions: A list with length of ACTION_SPACE of tensors of shapes
            [T, B, ...], ..., [T, B, ...]
            with actions sampled from the behavior policy.
        discounts: A float32 tensor of shape [T, B] with the discount
            encountered when following the behavior policy.
        rewards: A float32 tensor of shape [T, B] with the rewards generated by
            following the behavior policy.
        values: A float32 tensor of shape [T, B] with the value function
            estimates wrt. the target policy.
        bootstrap_value: A float32 tensor of shape [B] with the value function
            estimate at time T.
        dist_class: action distribution class for the logits.
        model: backing ModelV2 instance
        behaviour_action_log_probs: Precalculated values of the behavior
            actions.
        clip_rho_threshold: A scalar float32 tensor with the clipping threshold
            for importance weights (rho) when calculating the baseline targets
            (vs). rho^bar in the paper.
        clip_pg_rho_threshold: A scalar float32 tensor with the clipping
            threshold on rho_s in:
            \rho_s \delta \log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)).

    Returns:
        A `VTraceFromLogitsReturns` namedtuple with the following fields:
        vs: A float32 tensor of shape [T, B]. Can be used as target to train a
            baseline (V(x_t) - vs_t)^2.
        pg_advantages: A float32 tensor of shape [T, B]. Can be used as an
            estimate of the advantage in the calculation of policy gradients.
        log_rhos: A float32 tensor of shape [T, B] containing the log
            importance sampling weights (log rhos).
        behaviour_action_log_probs: A float32 tensor of shape [T, B] containing
            behaviour policy action log probabilities (log \mu(a_t)).
        target_action_log_probs: A float32 tensor of shape [T, B] containing
            target policy action log probabilities (log \pi(a_t)).
    """

    behaviour_policy_logits = convert_to_torch_tensor(behaviour_policy_logits,
                                                      device="cpu")
    target_policy_logits = convert_to_torch_tensor(target_policy_logits,
                                                   device="cpu")
    actions = convert_to_torch_tensor(actions, device="cpu")

    # Make sure tensor ranks are as expected.
    # The rest will be checked by from_action_log_probs.
    for i in range(len(behaviour_policy_logits)):
        assert len(behaviour_policy_logits[i].size()) == 3
        assert len(target_policy_logits[i].size()) == 3

    target_action_log_probs = multi_log_probs_from_logits_and_actions(
        target_policy_logits, actions, dist_class, model)

    if (len(behaviour_policy_logits) > 1
            or behaviour_action_log_probs is None):
        # Can't use the precalculated values; recompute them. Note that
        # recomputing won't work well for autoregressive action dists,
        # which may have variables not captured by 'logits'.
        behaviour_action_log_probs = multi_log_probs_from_logits_and_actions(
            behaviour_policy_logits, actions, dist_class, model)

    behaviour_action_log_probs = convert_to_torch_tensor(
        behaviour_action_log_probs, device="cpu")
    behaviour_action_log_probs = force_list(behaviour_action_log_probs)
    log_rhos = get_log_rhos(target_action_log_probs,
                            behaviour_action_log_probs)

    vtrace_returns = from_importance_weights(
        log_rhos=log_rhos,
        discounts=discounts,
        rewards=rewards,
        values=values,
        bootstrap_value=bootstrap_value,
        clip_rho_threshold=clip_rho_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold)

    return VTraceFromLogitsReturns(
        log_rhos=log_rhos,
        behaviour_action_log_probs=behaviour_action_log_probs,
        target_action_log_probs=target_action_log_probs,
        **vtrace_returns._asdict())
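
For reference, a minimal sketch of what `get_log_rhos` presumably computes from
the two log-prob lists above (an assumption for illustration, not necessarily
the library's exact implementation): the log importance-sampling ratios are the
differences of target and behaviour log-probs, summed over sub-action
components.

import torch


def get_log_rhos_sketch(target_action_log_probs, behaviour_action_log_probs):
    # Both inputs are lists of [T, B] tensors, one entry per sub-action space.
    diffs = [t - b for t, b in
             zip(target_action_log_probs, behaviour_action_log_probs)]
    # Sum the component-wise log-ratios -> a single [T, B] tensor.
    return torch.sum(torch.stack(diffs), dim=0)
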
Example #11
File: tf_policy.py, Project: parasj/ray
    def _initialize_loss(self, losses: List[TensorType],
                         loss_inputs: List[Tuple[str, TensorType]]) -> None:
        """Initializes the loss op from given loss tensor and placeholders.

        Args:
            losses (List[TensorType]): The list of loss ops returned by some
                loss function.
            loss_inputs (List[Tuple[str, TensorType]]): The list of
                (name, tf1.placeholder) tuples needed for calculating the
                loss.
        """
        self._loss_input_dict = dict(loss_inputs)
        self._loss_input_dict_no_rnn = {
            k: v
            for k, v in self._loss_input_dict.items()
            if (v not in self._state_inputs and v != self._seq_lens)
        }
        for i, ph in enumerate(self._state_inputs):
            self._loss_input_dict["state_in_{}".format(i)] = ph

        if self.model and not isinstance(self.model, tf.keras.Model):
            self._losses = force_list(
                self.model.custom_loss(losses, self._loss_input_dict))
            self._stats_fetches.update({"model": self.model.metrics()})
        else:
            self._losses = losses
        # Backward compatibility.
        self._loss = self._losses[0] if self._losses is not None else None

        if not self._optimizers:
            self._optimizers = force_list(self.optimizer())
            # Backward compatibility.
            self._optimizer = self._optimizers[0] if self._optimizers else None

        # Supporting more than one loss/optimizer.
        if self.config["_tf_policy_handles_more_than_one_loss"]:
            self._grads_and_vars = []
            self._grads = []
            for group in self.gradients(self._optimizers, self._losses):
                g_and_v = [(g, v) for (g, v) in group if g is not None]
                self._grads_and_vars.append(g_and_v)
                self._grads.append([g for (g, _) in g_and_v])
        # Only one optimizer and loss term.
        else:
            self._grads_and_vars = [
                (g, v)
                for (g, v) in self.gradients(self._optimizer, self._loss)
                if g is not None
            ]
            self._grads = [g for (g, _) in self._grads_and_vars]

        if self.model:
            self._variables = ray.experimental.tf_utils.TensorFlowVariables(
                [], self.get_session(), self.variables())

        # Gather update ops for any batch norm layers.
        if len(self.devices) <= 1:
            if not self._update_ops:
                self._update_ops = tf1.get_collection(
                    tf1.GraphKeys.UPDATE_OPS,
                    scope=tf1.get_variable_scope().name)
            if self._update_ops:
                logger.info("Update ops to run on apply gradient: {}".format(
                    self._update_ops))
            with tf1.control_dependencies(self._update_ops):
                self._apply_op = self.build_apply_op(
                    optimizer=self._optimizers
                    if self.config["_tf_policy_handles_more_than_one_loss"]
                    else self._optimizer,
                    grads_and_vars=self._grads_and_vars,
                )

        if log_once("loss_used"):
            logger.debug("These tensors were used in the loss functions:"
                         f"\n{summarize(self._loss_input_dict)}\n")

        self.get_session().run(tf1.global_variables_initializer())

        # TensorFlowVariables holding a flat list of all our optimizers'
        # variables.
        self._optimizer_variables = ray.experimental.tf_utils.TensorFlowVariables(
            [v for o in self._optimizers for v in o.variables()],
            self.get_session())
Example #12
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict,
        *,
        max_seq_len: int = 20,
    ):
        """Initializes a TorchPolicy instance.

        Args:
            observation_space: Observation space of the policy.
            action_space: Action space of the policy.
            config: The Policy's config dict.
            max_seq_len: Max sequence length for LSTM training.
        """
        self.framework = config["framework"] = "torch"

        super().__init__(observation_space, action_space, config)

        # Create model.
        model, dist_class = self._init_model_and_dist_class()

        # Create multi-GPU model towers, if necessary.
        # - The central main model will be stored under self.model, residing
        #   on self.device (normally, a CPU).
        # - Each GPU will have a copy of that model under
        #   self.model_gpu_towers, matching the devices in self.devices.
        # - Parallelization is done by splitting the train batch and passing
        #   it through the model copies in parallel, then averaging over the
        #   resulting gradients, applying these averages on the main model and
        #   updating all towers' weights from the main model.
        # - In case of just one device (1 (fake or real) GPU or 1 CPU), no
        #   parallelization will be done.

        # Get devices to build the graph on.
        worker_idx = self.config.get("worker_index", 0)
        if not config["_fake_gpus"] and ray.worker._mode(
        ) == ray.worker.LOCAL_MODE:
            num_gpus = 0
        elif worker_idx == 0:
            num_gpus = config["num_gpus"]
        else:
            num_gpus = config["num_gpus_per_worker"]
        gpu_ids = list(range(torch.cuda.device_count()))

        # Place on one or more CPU(s) when either:
        # - Fake GPU mode.
        # - num_gpus=0 (either set by user or we are in local_mode=True).
        # - No GPUs available.
        if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
            logger.info("TorchPolicy (worker={}) running on {}.".format(
                worker_idx if worker_idx > 0 else "local",
                "{} fake-GPUs".format(num_gpus)
                if config["_fake_gpus"] else "CPU",
            ))
            self.device = torch.device("cpu")
            self.devices = [
                self.device for _ in range(int(math.ceil(num_gpus)) or 1)
            ]
            self.model_gpu_towers = [
                model if i == 0 else copy.deepcopy(model)
                for i in range(int(math.ceil(num_gpus)) or 1)
            ]
            if hasattr(self, "target_model"):
                self.target_models = {
                    m: self.target_model
                    for m in self.model_gpu_towers
                }
            self.model = model
        # Place on one or more actual GPU(s), when:
        # - num_gpus > 0 (set by user) AND
        # - local_mode=False AND
        # - actual GPUs available AND
        # - non-fake GPU mode.
        else:
            logger.info("TorchPolicy (worker={}) running on {} GPU(s).".format(
                worker_idx if worker_idx > 0 else "local", num_gpus))
            # We are a remote worker (WORKER_MODE=1):
            # GPUs should be assigned to us by ray.
            if ray.worker._mode() == ray.worker.WORKER_MODE:
                gpu_ids = ray.get_gpu_ids()

            if len(gpu_ids) < num_gpus:
                raise ValueError(
                    "TorchPolicy was not able to find enough GPU IDs! Found "
                    f"{gpu_ids}, but num_gpus={num_gpus}.")

            self.devices = [
                torch.device("cuda:{}".format(i))
                for i, id_ in enumerate(gpu_ids) if i < num_gpus
            ]
            self.device = self.devices[0]
            ids = [id_ for i, id_ in enumerate(gpu_ids) if i < num_gpus]
            self.model_gpu_towers = []
            for i, _ in enumerate(ids):
                model_copy = copy.deepcopy(model)
                self.model_gpu_towers.append(model_copy.to(self.devices[i]))
            if hasattr(self, "target_model"):
                self.target_models = {
                    m: copy.deepcopy(self.target_model).to(self.devices[i])
                    for i, m in enumerate(self.model_gpu_towers)
                }
            self.model = self.model_gpu_towers[0]

        self.dist_class = dist_class
        self.unwrapped_model = model  # used to support DistributedDataParallel

        # Lock used for locking some methods on the object-level.
        # This prevents possible race conditions when calling the model
        # first, then its value function (e.g. in a loss function), in
        # between of which another model call is made (e.g. to compute an
        # action).
        self._lock = threading.RLock()

        self._state_inputs = self.model.get_initial_state()
        self._is_recurrent = len(self._state_inputs) > 0
        # Auto-update model's inference view requirements, if recurrent.
        self._update_model_view_requirements_from_init_state()
        # Combine view_requirements for Model and Policy.
        self.view_requirements.update(self.model.view_requirements)

        self.exploration = self._create_exploration()
        self._optimizers = force_list(self.optimizer())

        # Backward compatibility workaround so Policy will call self.loss() directly.
        # TODO(jungong): clean up after all policies are migrated to new sub-class
        # implementation.
        self._loss = None

        # Store which params (by index within the model's list of
        # parameters) should be updated per optimizer.
        # Maps optimizer idx to a set of param indices.
        self.multi_gpu_param_groups: List[Set[int]] = []
        main_params = {p: i for i, p in enumerate(self.model.parameters())}
        for o in self._optimizers:
            param_indices = []
            for pg_idx, pg in enumerate(o.param_groups):
                for p in pg["params"]:
                    param_indices.append(main_params[p])
            self.multi_gpu_param_groups.append(set(param_indices))

        # Create n sample-batch buffers (num_multi_gpu_tower_stacks), each
        # one with m towers (num_gpus).
        num_buffers = self.config.get("num_multi_gpu_tower_stacks", 1)
        self._loaded_batches = [[] for _ in range(num_buffers)]

        # If set, means we are using distributed allreduce during learning.
        self.distributed_world_size = None

        self.batch_divisibility_req = self.get_batch_divisibility_req()
        self.max_seq_len = max_seq_len
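
A standalone sketch of the `multi_gpu_param_groups` bookkeeping above, using a
hypothetical two-optimizer setup (the model and optimizers here are plain
PyTorch stand-ins, not RLlib objects):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
optimizers = [
    torch.optim.Adam(model[0].parameters()),  # owns layer-0 params
    torch.optim.Adam(model[1].parameters()),  # owns layer-1 params
]

# Map each parameter to its index in the model's parameter list, then record,
# per optimizer, which indices it updates.
main_params = {p: i for i, p in enumerate(model.parameters())}
multi_gpu_param_groups = [
    {main_params[p] for pg in o.param_groups for p in pg["params"]}
    for o in optimizers
]
print(multi_gpu_param_groups)  # -> [{0, 1}, {2, 3}]
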
Example #13
File: train.py, Project: haochihlin/ray
def run(args, parser):
    if args.config_file:
        with open(args.config_file) as f:
            experiments = yaml.safe_load(f)
    else:
        # Note: keep this in sync with tune/config_parser.py
        experiments = {
            args.experiment_name: {  # i.e. log to ~/ray_results/default
                "run": args.run,
                "checkpoint_freq": args.checkpoint_freq,
                "checkpoint_at_end": args.checkpoint_at_end,
                "keep_checkpoints_num": args.keep_checkpoints_num,
                "checkpoint_score_attr": args.checkpoint_score_attr,
                "local_dir": args.local_dir,
                "resources_per_trial": (
                    args.resources_per_trial and
                    resources_to_json(args.resources_per_trial)),
                "stop": args.stop,
                "config": dict(args.config, env=args.env),
                "restore": args.restore,
                "num_samples": args.num_samples,
                "upload_dir": args.upload_dir,
            }
        }

    verbose = 1
    for exp in experiments.values():
        # Bazel makes it hard to find files specified in `args` (and `data`).
        # Look for them here.
        # NOTE: Some of our yaml files don't have a `config` section.
        input_ = exp.get("config", {}).get("input")
        if input_ and input_ != "sampler":
            inputs = force_list(input_)
            # This script runs in the ray/rllib dir.
            rllib_dir = Path(__file__).parent

            def patch_path(path):
                if os.path.exists(path):
                    return path
                else:
                    abs_path = str(rllib_dir.absolute().joinpath(path))
                    return abs_path if os.path.exists(abs_path) else path

            abs_inputs = list(map(patch_path, inputs))
            if not isinstance(input_, list):
                abs_inputs = abs_inputs[0]

            exp["config"]["input"] = abs_inputs

        if not exp.get("run"):
            parser.error("the following arguments are required: --run")
        if not exp.get("env") and not exp.get("config", {}).get("env"):
            parser.error("the following arguments are required: --env")

        if args.torch:
            deprecation_warning("--torch", "--framework=torch")
            exp["config"]["framework"] = "torch"
        elif args.eager:
            deprecation_warning("--eager", "--framework=[tf2|tfe]")
            exp["config"]["framework"] = "tfe"
        elif args.framework is not None:
            exp["config"]["framework"] = args.framework

        if args.trace:
            if exp["config"]["framework"] not in ["tf2", "tfe"]:
                raise ValueError("Must enable --eager to enable tracing.")
            exp["config"]["eager_tracing"] = True

        if args.v:
            exp["config"]["log_level"] = "INFO"
            verbose = 3  # Print details on trial result
        if args.vv:
            exp["config"]["log_level"] = "DEBUG"
            verbose = 3  # Print details on trial result

    if args.ray_num_nodes:
        # Import this only here so that train.py also works with
        # older versions (and user doesn't use `--ray-num-nodes`).
        from ray.cluster_utils import Cluster
        cluster = Cluster()
        for _ in range(args.ray_num_nodes):
            cluster.add_node(
                num_cpus=args.ray_num_cpus or 1,
                num_gpus=args.ray_num_gpus or 0,
                object_store_memory=args.ray_object_store_memory)
        ray.init(address=cluster.address)
    else:
        ray.init(
            include_dashboard=not args.no_ray_ui,
            address=args.ray_address,
            object_store_memory=args.ray_object_store_memory,
            num_cpus=args.ray_num_cpus,
            num_gpus=args.ray_num_gpus,
            local_mode=args.local_mode)

    if IS_NOTEBOOK:
        progress_reporter = JupyterNotebookReporter(
            overwrite=verbose >= 3, print_intermediate_tables=verbose >= 1)
    else:
        progress_reporter = CLIReporter(print_intermediate_tables=verbose >= 1)

    run_experiments(
        experiments,
        scheduler=create_scheduler(args.scheduler, **args.scheduler_config),
        resume=args.resume,
        queue_trials=args.queue_trials,
        verbose=verbose,
        progress_reporter=progress_reporter,
        concurrent=True)

    ray.shutdown()
Example #14
    def learn_on_batch(self, postprocessed_batch):
        # Get batch ready for RNNs, if applicable.
        pad_batch_to_sequences_of_same_size(
            postprocessed_batch,
            max_seq_len=self.max_seq_len,
            shuffle=False,
            batch_divisibility_req=self.batch_divisibility_req)

        train_batch = self._lazy_tensor_dict(postprocessed_batch)
        loss_out = force_list(
            self._loss(self, self.model, self.dist_class, train_batch))
        assert len(loss_out) == len(self._optimizers)
        # assert not any(torch.isnan(l) for l in loss_out)

        # Loop through all optimizers.
        grad_info = {"allreduce_latency": 0.0}
        len_optim = len(self._optimizers)
        for i, opt in enumerate(self._optimizers):
            # Step all optimizers except the last two (encoder/decoder), which
            # are handled separately below.
            if i != len_optim - 1 and i != len_optim - 2:
                # Erase gradients in all vars of this optimizer.
                opt.zero_grad()
                # Recompute gradients of loss over all variables.
                loss_out[i].backward(
                    retain_graph=(i < len(self._optimizers) - 1))
                grad_info.update(self.extra_grad_process(opt, loss_out[i]))

                if self.distributed_world_size:
                    grads = []
                    for param_group in opt.param_groups:
                        for p in param_group["params"]:
                            if p.grad is not None:
                                grads.append(p.grad)

                    start = time.time()
                    if torch.cuda.is_available():
                        # Sadly, allreduce_coalesced does not work with CUDA yet.
                        for g in grads:
                            torch.distributed.all_reduce(
                                g, op=torch.distributed.ReduceOp.SUM)
                    else:
                        torch.distributed.all_reduce_coalesced(
                            grads, op=torch.distributed.ReduceOp.SUM)

                    for param_group in opt.param_groups:
                        for p in param_group["params"]:
                            if p.grad is not None:
                                p.grad /= self.distributed_world_size

                    grad_info["allreduce_latency"] += time.time() - start

                # Step the optimizer.
                opt.step()
        # Handle the ae_loss (auto-encoder) and its dedicated encoder/decoder
        # optimizers (the last two entries in `self._optimizers`).
        decoder_optimizer = self._optimizers[-1]
        encoder_optimizer = self._optimizers[-2]
        loss = loss_out[-1]
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()

        grad_info["allreduce_latency"] /= len(self._optimizers)
        grad_info.update(self.extra_grad_info(train_batch))
        return {LEARNER_STATS_KEY: grad_info}
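
The condition in the optimizer loop above must use logical `and` rather than
bitwise `&`: `&` binds tighter than `!=` and the comparisons then chain, so the
bitwise form would not express "skip the last two optimizers". A quick
standalone check with arbitrary numbers:

n = 4  # pretend number of optimizers
for i in range(n):
    buggy = i != n - 1 & i != n - 2   # parses as i != ((n - 1) & i) != (n - 2)
    fixed = i != n - 1 and i != n - 2  # intended: skip the last two indices
    print(i, buggy, fixed)
# The two columns differ for i = 0 and i = 1, i.e. exactly the optimizers that
# the loop is supposed to step.
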
Example #15
    def __init__(self,
                 observation_space,
                 action_space,
                 config,
                 *,
                 model,
                 loss,
                 action_distribution_class,
                 action_sampler_fn=None,
                 action_distribution_fn=None,
                 max_seq_len=20,
                 get_batch_divisibility_req=None):
        """Build a policy from policy and loss torch modules.

        Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES
        is set. Only single GPU is supported for now.

        Args:
            observation_space (gym.Space): observation space of the policy.
            action_space (gym.Space): action space of the policy.
            config (dict): The Policy config dict.
            model (nn.Module): PyTorch policy module. Given observations as
                input, this module must return a list of outputs where the
                first item is action logits, and the rest can be any value.
            loss (func): Function that takes (policy, model, dist_class,
                train_batch) and returns a single scalar loss.
            action_distribution_class (ActionDistribution): Class for action
                distribution.
            action_sampler_fn (Optional[callable]): A callable returning a
                sampled action and its log-likelihood given some (obs and
                state) inputs.
            action_distribution_fn (Optional[callable]): A callable returning
                distribution inputs (parameters), a dist-class to generate an
                action distribution object from, and internal-state outputs
                (or an empty list if not applicable).
                Note: No Exploration hooks need to be called from within
                `action_distribution_fn`. It should only perform a simple
                forward pass through some model.
                If None, pass inputs through `self.model()` to get the
                distribution inputs.
            max_seq_len (int): Max sequence length for LSTM training.
            get_batch_divisibility_req (Optional[callable]): Optional callable
                that returns the divisibility requirement for sample batches.
        """
        self.framework = "torch"
        super().__init__(observation_space, action_space, config)
        self.device = (torch.device("cuda")
                       if torch.cuda.is_available() else torch.device("cpu"))
        self.model = model.to(self.device)
        self.exploration = self._create_exploration()
        self.unwrapped_model = model  # used to support DistributedDataParallel
        self._loss = loss
        self._optimizers = force_list(self.optimizer())

        self.dist_class = action_distribution_class
        self.action_sampler_fn = action_sampler_fn
        self.action_distribution_fn = action_distribution_fn

        # If set, means we are using distributed allreduce during learning.
        self.distributed_world_size = None

        self.max_seq_len = max_seq_len
        self.batch_divisibility_req = \
            get_batch_divisibility_req(self) if get_batch_divisibility_req \
            else 1
Example #16
        def __init__(self, observation_space, action_space, config):
            assert tf.executing_eagerly()
            self.framework = config.get("framework", "tfe")
            Policy.__init__(self, observation_space, action_space, config)
            self._is_training = False
            self._loss_initialized = False
            self._sess = None

            self._loss = loss_fn
            self.batch_divisibility_req = get_batch_divisibility_req(self) if \
                callable(get_batch_divisibility_req) else \
                (get_batch_divisibility_req or 1)
            self._max_seq_len = config["model"]["max_seq_len"]

            if get_default_config:
                config = dict(get_default_config(), **config)

            if validate_spaces:
                validate_spaces(self, observation_space, action_space, config)

            if before_init:
                before_init(self, observation_space, action_space, config)

            self.config = config
            self.dist_class = None
            if action_sampler_fn or action_distribution_fn:
                if not make_model:
                    raise ValueError(
                        "`make_model` is required if `action_sampler_fn` OR "
                        "`action_distribution_fn` is given")
            else:
                self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"])

            if make_model:
                self.model = make_model(self, observation_space, action_space,
                                        config)
            else:
                self.model = ModelCatalog.get_model_v2(
                    observation_space,
                    action_space,
                    logit_dim,
                    config["model"],
                    framework=self.framework,
                )
            # Auto-update model's inference view requirements, if recurrent.
            self._update_model_inference_view_requirements_from_init_state()

            self.exploration = self._create_exploration()
            self._state_in = [
                tf.convert_to_tensor([s])
                for s in self.model.get_initial_state()
            ]

            # Combine view_requirements for Model and Policy.
            self.view_requirements.update(
                self.model.inference_view_requirements)

            if before_loss_init:
                before_loss_init(self, observation_space, action_space, config)

            if optimizer_fn:
                optimizers = optimizer_fn(self, config)
            else:
                optimizers = tf.keras.optimizers.Adam(config["lr"])
            optimizers = force_list(optimizers)
            if getattr(self, "exploration", None):
                optimizers = self.exploration.get_exploration_optimizer(
                    optimizers)
            # TODO: (sven) Allow tf policy to have more than 1 optimizer.
            #  Just like torch Policy does.
            self._optimizer = optimizers[0] if optimizers else None

            self._initialize_loss_from_dummy_batch(
                auto_remove_unneeded_view_reqs=True,
                stats_fn=stats_fn,
            )
            self._loss_initialized = True

            if after_init:
                after_init(self, observation_space, action_space, config)

            # Got to reset global_timestep again after fake run-throughs.
            self.global_timestep = 0
Example #17
    def __init__(
            self,
            policy_map: Dict[PolicyID, Policy],
            callbacks: "DefaultCallbacks",
            # TODO: (sven) make `num_agents` flexibly grow in size.
            num_agents: int = 100,
            num_timesteps=None,
            time_major: Optional[bool] = False):
        """Initializes a _MultiAgentSampleCollector object.

        Args:
            policy_map (Dict[PolicyID,Policy]): Maps policy ids to policy
                instances.
            callbacks (DefaultCallbacks): RLlib callbacks (configured in the
                Trainer config dict). Used for trajectory postprocessing event.
            num_agents (int): The max number of agent slots to pre-allocate
                in the buffer.
            num_timesteps (int): The max number of timesteps to pre-allocate
                in the buffer.
            time_major (Optional[bool]): Whether to preallocate buffers and
                collect samples in time-major fashion (TxBx...).
        """

        self.policy_map = policy_map
        self.callbacks = callbacks
        if num_agents == float("inf") or num_agents is None:
            num_agents = 1000
        self.num_agents = int(num_agents)

        # Collect SampleBatches per-policy in _PerPolicySampleCollectors.
        self.policy_sample_collectors = {}
        for pid, policy in policy_map.items():
            # Figure out max-shifts (before and after).
            view_reqs = policy.training_view_requirements
            max_shift_before = 0
            max_shift_after = 0
            for vr in view_reqs.values():
                shift = force_list(vr.shift)
                if max_shift_before > shift[0]:
                    max_shift_before = shift[0]
                if max_shift_after < shift[-1]:
                    max_shift_after = shift[-1]
            # Figure out num_timesteps and num_agents.
            kwargs = {"time_major": time_major}
            if policy.is_recurrent():
                kwargs["num_timesteps"] = \
                    policy.config["model"]["max_seq_len"]
                kwargs["time_major"] = True
            elif num_timesteps is not None:
                kwargs["num_timesteps"] = num_timesteps

            self.policy_sample_collectors[pid] = _PerPolicySampleCollector(
                num_agents=self.num_agents,
                shift_before=-max_shift_before,
                shift_after=max_shift_after,
                **kwargs)

        # Internal agent-to-policy map.
        self.agent_to_policy = {}
        # Number of "inference" steps taken in the environment.
        # Regardless of the number of agents involved in each of these steps.
        self.count = 0
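
A minimal sketch of the max-shift computation above, with plain integer shift
lists standing in for real view-requirement `shift` values (an assumption for
illustration):

# e.g. shifts for "obs" (0), "prev_actions" (-1), and "next_obs" (+1).
view_req_shifts = [[0], [-1], [1]]

max_shift_before = 0
max_shift_after = 0
for shift in view_req_shifts:
    if max_shift_before > shift[0]:
        max_shift_before = shift[0]
    if max_shift_after < shift[-1]:
        max_shift_after = shift[-1]

# Buffers then need -max_shift_before extra slots before t=0 and
# max_shift_after extra slots after the episode end.
print(-max_shift_before, max_shift_after)  # -> 1 1
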
Example #18
    def compute_gradients(self,
                          postprocessed_batch: SampleBatch) -> ModelGradients:

        pad_batch_to_sequences_of_same_size(
            postprocessed_batch,
            max_seq_len=self.max_seq_len,
            shuffle=False,
            batch_divisibility_req=self.batch_divisibility_req,
            view_requirements=self.view_requirements,
        )

        # Mark the batch as "is_training" so the Model can use this
        # information.
        postprocessed_batch["is_training"] = True
        train_batch = self._lazy_tensor_dict(postprocessed_batch)

        # Calculate the actual policy loss.
        loss_out = force_list(
            self._loss(self, self.model, self.dist_class, train_batch))

        # Call Model's custom-loss with Policy loss outputs and train_batch.
        if self.model:
            loss_out = self.model.custom_loss(loss_out, train_batch)

        # Give the Exploration component the chance to modify the loss (or add
        # its own terms).
        if hasattr(self, "exploration"):
            loss_out = self.exploration.get_exploration_loss(
                loss_out, train_batch)

        assert len(loss_out) == len(self._optimizers)

        # assert not any(torch.isnan(l) for l in loss_out)
        fetches = self.extra_compute_grad_fetches()

        # Loop through all optimizers.
        grad_info = {"allreduce_latency": 0.0}

        all_grads = []
        for i, opt in enumerate(self._optimizers):
            # Erase gradients in all vars of this optimizer.
            opt.zero_grad()
            # Recompute gradients of loss over all variables.
            loss_out[i].backward(retain_graph=(i < len(self._optimizers) - 1))
            grad_info.update(self.extra_grad_process(opt, loss_out[i]))

            grads = []
            # Note that return values are just references;
            # Calling zero_grad would modify the values.
            for param_group in opt.param_groups:
                for p in param_group["params"]:
                    if p.grad is not None:
                        grads.append(p.grad)
                        all_grads.append(p.grad.data.cpu().numpy())
                    else:
                        all_grads.append(None)

            if self.distributed_world_size:
                start = time.time()
                if torch.cuda.is_available():
                    # Sadly, allreduce_coalesced does not work with CUDA yet.
                    for g in grads:
                        torch.distributed.all_reduce(
                            g, op=torch.distributed.ReduceOp.SUM)
                else:
                    torch.distributed.all_reduce_coalesced(
                        grads, op=torch.distributed.ReduceOp.SUM)

                for param_group in opt.param_groups:
                    for p in param_group["params"]:
                        if p.grad is not None:
                            p.grad /= self.distributed_world_size

                grad_info["allreduce_latency"] += time.time() - start

        grad_info["allreduce_latency"] /= len(self._optimizers)
        grad_info.update(self.extra_grad_info(train_batch))

        return all_grads, dict(fetches, **{LEARNER_STATS_KEY: grad_info})
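
The allreduce bookkeeping above amounts to averaging gradients across workers:
summing with `ReduceOp.SUM` and then dividing by `distributed_world_size`. A
tiny single-process illustration of that arithmetic (no real
`torch.distributed` setup, just stacked per-worker gradients):

import torch

world_size = 4
per_worker_grads = [torch.randn(3) for _ in range(world_size)]

summed = torch.stack(per_worker_grads).sum(dim=0)  # what all_reduce(SUM) yields
averaged = summed / world_size                     # the division in the loop above
assert torch.allclose(averaged, torch.stack(per_worker_grads).mean(dim=0))
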
Example #19
        def __init__(self, observation_space, action_space, config):
            assert tf.executing_eagerly()
            self.framework = config.get("framework", "tfe")
            Policy.__init__(self, observation_space, action_space, config)
            self._is_training = False
            self._loss_initialized = False
            self._sess = None

            self._loss = loss_fn
            self.batch_divisibility_req = get_batch_divisibility_req(self) if \
                callable(get_batch_divisibility_req) else \
                (get_batch_divisibility_req or 1)
            self._max_seq_len = config["model"]["max_seq_len"]

            if get_default_config:
                config = dict(get_default_config(), **config)

            if validate_spaces:
                validate_spaces(self, observation_space, action_space, config)

            if before_init:
                before_init(self, observation_space, action_space, config)

            self.config = config
            self.dist_class = None
            if action_sampler_fn or action_distribution_fn:
                if not make_model:
                    raise ValueError(
                        "`make_model` is required if `action_sampler_fn` OR "
                        "`action_distribution_fn` is given")
            else:
                self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"])

            if make_model:
                self.model = make_model(self, observation_space, action_space,
                                        config)
            else:
                self.model = ModelCatalog.get_model_v2(
                    observation_space,
                    action_space,
                    logit_dim,
                    config["model"],
                    framework=self.framework,
                )
            # Lock used for locking some methods on the object-level.
            # This prevents possible race conditions when calling the model
            # first, then its value function (e.g. in a loss function), in
            # between of which another model call is made (e.g. to compute an
            # action).
            self._lock = threading.RLock()

            # Auto-update model's inference view requirements, if recurrent.
            self._update_model_view_requirements_from_init_state()

            self.exploration = self._create_exploration()
            self._state_inputs = self.model.get_initial_state()
            self._is_recurrent = len(self._state_inputs) > 0

            # Combine view_requirements for Model and Policy.
            self.view_requirements.update(self.model.view_requirements)

            if before_loss_init:
                before_loss_init(self, observation_space, action_space, config)

            if optimizer_fn:
                optimizers = optimizer_fn(self, config)
            else:
                optimizers = tf.keras.optimizers.Adam(config["lr"])
            optimizers = force_list(optimizers)
            if getattr(self, "exploration", None):
                optimizers = self.exploration.get_exploration_optimizer(
                    optimizers)
            # TODO: (sven) Allow tf policy to have more than 1 optimizer.
            #  Just like torch Policy does.
            self._optimizer = optimizers[0] if optimizers else None

            self._initialize_loss_from_dummy_batch(
                auto_remove_unneeded_view_reqs=True,
                stats_fn=stats_fn,
            )
            self._loss_initialized = True

            if after_init:
                after_init(self, observation_space, action_space, config)

            # Got to reset global_timestep again after fake run-throughs.
            self.global_timestep = 0
Example #20
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict,
        *,
        model: ModelV2,
        loss: Callable[
            [Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch],
            Union[TensorType, List[TensorType]]],
        action_distribution_class: Type[TorchDistributionWrapper],
        action_sampler_fn: Optional[Callable[[TensorType, List[TensorType]],
                                             Tuple[TensorType,
                                                   TensorType]]] = None,
        action_distribution_fn: Optional[
            Callable[[Policy, ModelV2, TensorType, TensorType, TensorType],
                     Tuple[TensorType, Type[TorchDistributionWrapper],
                           List[TensorType]]]] = None,
        max_seq_len: int = 20,
        get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
    ):
        """Build a policy from policy and loss torch modules.

        Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES
        is set. Only single GPU is supported for now.

        Args:
            observation_space (gym.spaces.Space): observation space of the
                policy.
            action_space (gym.spaces.Space): action space of the policy.
            config (TrainerConfigDict): The Policy config dict.
            model (ModelV2): PyTorch policy module. Given observations as
                input, this module must return a list of outputs where the
                first item is action logits, and the rest can be any value.
            loss (Callable[[Policy, ModelV2, Type[TorchDistributionWrapper],
                SampleBatch], Union[TensorType, List[TensorType]]]): Callable
                that returns a single scalar loss or a list of loss terms.
            action_distribution_class (Type[TorchDistributionWrapper]): Class
                for a torch action distribution.
            action_sampler_fn (Callable[[TensorType, List[TensorType]],
                Tuple[TensorType, TensorType]]): A callable returning a
                sampled action and its log-likelihood given Policy, ModelV2,
                input_dict, explore, timestep, and is_training.
            action_distribution_fn (Optional[Callable[[Policy, ModelV2,
                Dict[str, TensorType], TensorType, TensorType],
                Tuple[TensorType, type, List[TensorType]]]]): A callable
                returning distribution inputs (parameters), a dist-class to
                generate an action distribution object from, and
                internal-state outputs (or an empty list if not applicable).
                Note: No Exploration hooks need to be called from within
                `action_distribution_fn`. It should only perform a simple
                forward pass through some model.
                If None, pass inputs through `self.model()` to get distribution
                inputs.
                The callable takes as inputs: Policy, ModelV2, input_dict,
                explore, timestep, is_training.
            max_seq_len (int): Max sequence length for LSTM training.
            get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
                Optional callable that returns the divisibility requirement
                for sample batches given the Policy.
        """
        self.framework = "torch"
        super().__init__(observation_space, action_space, config)
        if torch.cuda.is_available():
            logger.info("TorchPolicy running on GPU.")
            self.device = torch.device("cuda")
        else:
            logger.info("TorchPolicy running on CPU.")
            self.device = torch.device("cpu")
        self.model = model.to(self.device)

        # Lock used for locking some methods on the object-level.
        # This prevents possible race conditions when calling the model
        # first, then its value function (e.g. in a loss function), in
        # between of which another model call is made (e.g. to compute an
        # action).
        self._lock = threading.RLock()

        self._state_inputs = self.model.get_initial_state()
        self._is_recurrent = len(self._state_inputs) > 0
        # Auto-update model's inference view requirements, if recurrent.
        self._update_model_view_requirements_from_init_state()
        # Combine view_requirements for Model and Policy.
        self.view_requirements.update(self.model.view_requirements)

        self.exploration = self._create_exploration()
        self.unwrapped_model = model  # used to support DistributedDataParallel
        self._loss = loss
        self._optimizers = force_list(self.optimizer())

        self.dist_class = action_distribution_class
        self.action_sampler_fn = action_sampler_fn
        self.action_distribution_fn = action_distribution_fn

        # If set, means we are using distributed allreduce during learning.
        self.distributed_world_size = None

        self.max_seq_len = max_seq_len
        self.batch_divisibility_req = get_batch_divisibility_req(self) if \
            callable(get_batch_divisibility_req) else \
            (get_batch_divisibility_req or 1)
Example #21
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict,
        *,
        model: Optional[TorchModelV2] = None,
        loss: Optional[Callable[
            [Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch],
            Union[TensorType, List[TensorType]], ]] = None,
        action_distribution_class: Optional[
            Type[TorchDistributionWrapper]] = None,
        action_sampler_fn: Optional[Callable[[TensorType, List[TensorType]],
                                             Tuple[TensorType,
                                                   TensorType]]] = None,
        action_distribution_fn: Optional[
            Callable[[Policy, ModelV2, TensorType, TensorType, TensorType],
                     Tuple[TensorType, Type[TorchDistributionWrapper],
                           List[TensorType]], ]] = None,
        max_seq_len: int = 20,
        get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
    ):
        """Initializes a TorchPolicy instance.

        Args:
            observation_space: Observation space of the policy.
            action_space: Action space of the policy.
            config: The Policy's config dict.
            model: PyTorch policy module. Given observations as
                input, this module must return a list of outputs where the
                first item is action logits, and the rest can be any value.
            loss: Callable that returns one or more (a list of) scalar loss
                terms.
            action_distribution_class: Class for a torch action distribution.
            action_sampler_fn: A callable returning a sampled action and its
                log-likelihood given Policy, ModelV2, input_dict, state batches
                (optional), explore, and timestep.
                Provide `action_sampler_fn` if you would like to have full
                control over the action computation step, including the
                model forward pass, possible sampling from a distribution,
                and exploration logic.
                Note: If `action_sampler_fn` is given, `action_distribution_fn`
                must be None. If both `action_sampler_fn` and
                `action_distribution_fn` are None, RLlib will simply pass
                inputs through `self.model` to get distribution inputs, create
                the distribution object, sample from it, and apply some
                exploration logic to the results.
                The callable takes as inputs: Policy, ModelV2, input_dict
                (SampleBatch), state_batches (optional), explore, and timestep.
            action_distribution_fn: A callable returning distribution inputs
                (parameters), a dist-class to generate an action distribution
                object from, and internal-state outputs (or an empty list if
                not applicable).
                Provide `action_distribution_fn` if you would like to only
                customize the model forward pass call. The resulting
                distribution parameters are then used by RLlib to create a
                distribution object, sample from it, and execute any
                exploration logic.
                Note: If `action_distribution_fn` is given, `action_sampler_fn`
                must be None. If both `action_sampler_fn` and
                `action_distribution_fn` are None, RLlib will simply pass
                inputs through `self.model` to get distribution inputs, create
                the distribution object, sample from it, and apply some
                exploration logic to the results.
                The callable takes as inputs: Policy, ModelV2, ModelInputDict,
                explore, timestep, is_training.
            max_seq_len: Max sequence length for LSTM training.
            get_batch_divisibility_req: Optional callable that returns the
                divisibility requirement for sample batches given the Policy.
        """
        self.framework = config["framework"] = "torch"
        super().__init__(observation_space, action_space, config)

        # Create multi-GPU model towers, if necessary.
        # - The central main model will be stored under self.model, residing
        #   on self.device (normally, a CPU).
        # - Each GPU will have a copy of that model under
        #   self.model_gpu_towers, matching the devices in self.devices.
        # - Parallelization is done by splitting the train batch and passing
        #   it through the model copies in parallel, then averaging over the
        #   resulting gradients, applying these averages on the main model and
        #   updating all towers' weights from the main model.
        # - In case of just one device (1 (fake or real) GPU or 1 CPU), no
        #   parallelization will be done.

        # If no Model is provided, build a default one here.
        if model is None:
            dist_class, logit_dim = ModelCatalog.get_action_dist(
                action_space, self.config["model"], framework=self.framework)
            model = ModelCatalog.get_model_v2(
                obs_space=self.observation_space,
                action_space=self.action_space,
                num_outputs=logit_dim,
                model_config=self.config["model"],
                framework=self.framework,
            )
            if action_distribution_class is None:
                action_distribution_class = dist_class

        # Get devices to build the graph on.
        worker_idx = self.config.get("worker_index", 0)
        if not config["_fake_gpus"] and ray.worker._mode(
        ) == ray.worker.LOCAL_MODE:
            num_gpus = 0
        elif worker_idx == 0:
            num_gpus = config["num_gpus"]
        else:
            num_gpus = config["num_gpus_per_worker"]
        gpu_ids = list(range(torch.cuda.device_count()))

        # Place on one or more CPU(s) when either:
        # - Fake GPU mode.
        # - num_gpus=0 (either set by user or we are in local_mode=True).
        # - No GPUs available.
        if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
            logger.info("TorchPolicy (worker={}) running on {}.".format(
                worker_idx if worker_idx > 0 else "local",
                "{} fake-GPUs".format(num_gpus)
                if config["_fake_gpus"] else "CPU",
            ))
            self.device = torch.device("cpu")
            self.devices = [
                self.device for _ in range(int(math.ceil(num_gpus)) or 1)
            ]
            self.model_gpu_towers = [
                model if i == 0 else copy.deepcopy(model)
                for i in range(int(math.ceil(num_gpus)) or 1)
            ]
            if hasattr(self, "target_model"):
                self.target_models = {
                    m: self.target_model
                    for m in self.model_gpu_towers
                }
            self.model = model
        # Place on one or more actual GPU(s), when:
        # - num_gpus > 0 (set by user) AND
        # - local_mode=False AND
        # - actual GPUs available AND
        # - non-fake GPU mode.
        else:
            logger.info("TorchPolicy (worker={}) running on {} GPU(s).".format(
                worker_idx if worker_idx > 0 else "local", num_gpus))
            # We are a remote worker (WORKER_MODE=1):
            # GPUs should be assigned to us by ray.
            if ray.worker._mode() == ray.worker.WORKER_MODE:
                gpu_ids = ray.get_gpu_ids()

            if len(gpu_ids) < num_gpus:
                raise ValueError(
                    "TorchPolicy was not able to find enough GPU IDs! Found "
                    f"{gpu_ids}, but num_gpus={num_gpus}.")

            self.devices = [
                torch.device("cuda:{}".format(i))
                for i, id_ in enumerate(gpu_ids) if i < num_gpus
            ]
            self.device = self.devices[0]
            ids = [id_ for i, id_ in enumerate(gpu_ids) if i < num_gpus]
            self.model_gpu_towers = []
            for i, _ in enumerate(ids):
                model_copy = copy.deepcopy(model)
                self.model_gpu_towers.append(model_copy.to(self.devices[i]))
            if hasattr(self, "target_model"):
                self.target_models = {
                    m: copy.deepcopy(self.target_model).to(self.devices[i])
                    for i, m in enumerate(self.model_gpu_towers)
                }
            self.model = self.model_gpu_towers[0]

        # Lock used for locking some methods on the object-level.
        # This prevents possible race conditions when calling the model
        # first, then its value function (e.g. in a loss function), in
        # between of which another model call is made (e.g. to compute an
        # action).
        self._lock = threading.RLock()

        self._state_inputs = self.model.get_initial_state()
        self._is_recurrent = len(self._state_inputs) > 0
        # Auto-update model's inference view requirements, if recurrent.
        self._update_model_view_requirements_from_init_state()
        # Combine view_requirements for Model and Policy.
        self.view_requirements.update(self.model.view_requirements)

        self.exploration = self._create_exploration()
        self.unwrapped_model = model  # used to support DistributedDataParallel
        # To ensure backward compatibility:
        # Old way: If `loss` provided here, use as-is (as a function).
        if loss is not None:
            self._loss = loss
        # New way: Convert the overridden `self.loss` into a plain function,
        # so it can be called the same way as `loss` would be, ensuring
        # backward compatibility.
        elif self.loss.__func__.__qualname__ != "Policy.loss":
            self._loss = self.loss.__func__
        # `loss` not provided nor overridden from Policy -> Set to None.
        else:
            self._loss = None
        self._optimizers = force_list(self.optimizer())
        # Store which params (by index within the model's list of
        # parameters) should be updated per optimizer.
        # Maps optimizer idx to a set of param indices.
        self.multi_gpu_param_groups: List[Set[int]] = []
        main_params = {p: i for i, p in enumerate(self.model.parameters())}
        for o in self._optimizers:
            param_indices = []
            for pg_idx, pg in enumerate(o.param_groups):
                for p in pg["params"]:
                    param_indices.append(main_params[p])
            self.multi_gpu_param_groups.append(set(param_indices))

        # Create n sample-batch buffers (num_multi_gpu_tower_stacks), each
        # one with m towers (num_gpus).
        num_buffers = self.config.get("num_multi_gpu_tower_stacks", 1)
        self._loaded_batches = [[] for _ in range(num_buffers)]

        self.dist_class = action_distribution_class
        self.action_sampler_fn = action_sampler_fn
        self.action_distribution_fn = action_distribution_fn

        # If set, means we are using distributed allreduce during learning.
        self.distributed_world_size = None

        self.max_seq_len = max_seq_len
        self.batch_divisibility_req = (get_batch_divisibility_req(self)
                                       if callable(get_batch_divisibility_req)
                                       else (get_batch_divisibility_req or 1))
Example #22
    def learn_on_batch(
            self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
        # Get batch ready for RNNs, if applicable.
        pad_batch_to_sequences_of_same_size(
            postprocessed_batch,
            max_seq_len=self.max_seq_len,
            shuffle=False,
            batch_divisibility_req=self.batch_divisibility_req)

        train_batch = self._lazy_tensor_dict(postprocessed_batch)
        loss_out = force_list(
            self._loss(self, self.model, self.dist_class, train_batch))
        # Call Model's custom-loss with Policy loss outputs and train_batch.
        if self.model:
            loss_out = self.model.custom_loss(loss_out, train_batch)
        # Modifies the loss as specified by the Exploration strategy.
        if hasattr(self, "exploration"):
            loss_out = self.exploration.get_exploration_loss(
                loss_out, train_batch)
        assert len(loss_out) == len(self._optimizers)
        # assert not any(torch.isnan(l) for l in loss_out)
        fetches = self.extra_compute_grad_fetches()

        # Loop through all optimizers.
        grad_info = {"allreduce_latency": 0.0}
        for i, opt in enumerate(self._optimizers):
            # Erase gradients in all vars of this optimizer.
            opt.zero_grad()
            # Recompute gradients of loss over all variables.
            loss_out[i].backward(retain_graph=(i < len(self._optimizers) - 1))
            grad_info.update(self.extra_grad_process(opt, loss_out[i]))

            if self.distributed_world_size:
                grads = []
                for param_group in opt.param_groups:
                    for p in param_group["params"]:
                        if p.grad is not None:
                            grads.append(p.grad)

                start = time.time()
                if torch.cuda.is_available():
                    # Sadly, allreduce_coalesced does not work with CUDA yet.
                    for g in grads:
                        torch.distributed.all_reduce(
                            g, op=torch.distributed.ReduceOp.SUM)
                else:
                    torch.distributed.all_reduce_coalesced(
                        grads, op=torch.distributed.ReduceOp.SUM)

                for param_group in opt.param_groups:
                    for p in param_group["params"]:
                        if p.grad is not None:
                            p.grad /= self.distributed_world_size

                grad_info["allreduce_latency"] += time.time() - start

            # Step the optimizer.
            opt.step()

        grad_info["allreduce_latency"] /= len(self._optimizers)
        grad_info.update(self.extra_grad_info(train_batch))
        return dict(fetches, **{LEARNER_STATS_KEY: grad_info})
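
A standalone sketch of the per-optimizer update loop in `learn_on_batch` above, reduced to toy tensors (the model, optimizers, and loss terms are assumptions for illustration): each loss term is backpropagated separately, and the graph is retained for every backward pass except the last one because all loss terms come from the same forward pass.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
opts = [torch.optim.SGD(model[0].parameters(), lr=0.1),
        torch.optim.SGD(model[1].parameters(), lr=0.1)]

x = torch.randn(8, 4)
out = model(x)
# Two loss terms that share the same forward graph.
losses = [out.pow(2).mean(), out.abs().mean()]

for i, opt in enumerate(opts):
    # Erase gradients of this optimizer's params, then backprop its loss term.
    opt.zero_grad()
    losses[i].backward(retain_graph=(i < len(opts) - 1))
    opt.step()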
Example #23
0
        def _compute_gradients_helper(self, samples):
            """Computes and returns grads as eager tensors."""

            # Increase the tracing counter to make sure we don't re-trace too
            # often. If eager_tracing=True, this counter should only get
            # incremented during the @tf.function trace operations, never when
            # calling the already traced function after that.
            self._re_trace_counter += 1

            # Gather all variables for which to calculate losses.
            if isinstance(self.model, tf.keras.Model):
                variables = self.model.trainable_variables
            else:
                variables = self.model.trainable_variables()

            # Calculate the loss(es) inside a tf GradientTape.
            with tf.GradientTape(
                    persistent=compute_gradients_fn is not None) as tape:
                losses = self._loss(self, self.model, self.dist_class, samples)
            losses = force_list(losses)

            # User provided a compute_gradients_fn.
            if compute_gradients_fn:
                # Wrap our tape inside a wrapper, such that the resulting
                # object looks like a "classic" tf.optimizer. This way, custom
                # compute_gradients_fn will work on both tf static graph
                # and tf-eager.
                optimizer = OptimizerWrapper(tape)
                # More than one loss term/optimizer.
                if self.config["_tf_policy_handles_more_than_one_loss"]:
                    grads_and_vars = compute_gradients_fn(
                        self, [optimizer] * len(losses), losses)
                # Only one loss and one optimizer.
                else:
                    grads_and_vars = [
                        compute_gradients_fn(self, optimizer, losses[0])
                    ]
            # Default: Compute gradients using the above tape.
            else:
                grads_and_vars = [
                    list(zip(tape.gradient(loss, variables), variables))
                    for loss in losses
                ]

            if log_once("grad_vars"):
                for g_and_v in grads_and_vars:
                    for g, v in g_and_v:
                        if g is not None:
                            logger.info(f"Optimizing variable {v.name}")

            # `grads_and_vars` is returned as a list (len=num optimizers/losses)
            # of lists of (grad, var) tuples.
            if self.config["_tf_policy_handles_more_than_one_loss"]:
                grads = [[g for g, _ in g_and_v] for g_and_v in grads_and_vars]
            # `grads_and_vars` is returned as a list of (grad, var) tuples.
            else:
                grads_and_vars = grads_and_vars[0]
                grads = [g for g, _ in grads_and_vars]

            stats = self._stats(self, samples, grads)
            return grads_and_vars, grads, stats
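
The default branch above (no `compute_gradients_fn`) simply zips the tape's gradients with the trainable variables, producing one (grad, var) list per loss term. A small standalone TensorFlow sketch of that pattern follows; the Keras model and the loss terms are illustrative assumptions.

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
x = tf.random.normal([8, 4])

# Record the forward pass; persistent=True because the tape is queried once
# per loss term below.
with tf.GradientTape(persistent=True) as tape:
    out = model(x)
    losses = [tf.reduce_mean(tf.square(out)), tf.reduce_mean(tf.abs(out))]

variables = model.trainable_variables
grads_and_vars = [
    list(zip(tape.gradient(loss, variables), variables)) for loss in losses
]
del tape  # release resources held by the persistent tape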
Example #24
0
File: torch_policy.py Project: yangysc/ray
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict,
        *,
        model: ModelV2,
        loss: Callable[
            [Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch],
            Union[TensorType, List[TensorType]]],
        action_distribution_class: Type[TorchDistributionWrapper],
        action_sampler_fn: Optional[Callable[[TensorType, List[TensorType]],
                                             Tuple[TensorType,
                                                   TensorType]]] = None,
        action_distribution_fn: Optional[
            Callable[[Policy, ModelV2, TensorType, TensorType, TensorType],
                     Tuple[TensorType, Type[TorchDistributionWrapper],
                           List[TensorType]]]] = None,
        max_seq_len: int = 20,
        get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
    ):
        """Build a policy from policy and loss torch modules.

        Note that the model will be placed on a GPU device if
        CUDA_VISIBLE_DEVICES is set. Only a single GPU is supported for now.

        Args:
            observation_space (gym.spaces.Space): observation space of the
                policy.
            action_space (gym.spaces.Space): action space of the policy.
            config (TrainerConfigDict): The Policy config dict.
            model (ModelV2): PyTorch policy module. Given observations as
                input, this module must return a list of outputs where the
                first item is action logits, and the rest can be any value.
            loss (Callable[[Policy, ModelV2, Type[TorchDistributionWrapper],
                SampleBatch], Union[TensorType, List[TensorType]]]): Callable
                that returns a single scalar loss or a list of loss terms.
            action_distribution_class (Type[TorchDistributionWrapper]): Class
                for a torch action distribution.
            action_sampler_fn (Callable[[TensorType, List[TensorType]],
                Tuple[TensorType, TensorType]]): A callable returning a
                sampled action and its log-likelihood given Policy, ModelV2,
                input_dict, explore, timestep, and is_training.
            action_distribution_fn (Optional[Callable[[Policy, ModelV2,
                ModelInputDict, TensorType, TensorType],
                Tuple[TensorType, type, List[TensorType]]]]): A callable
                returning distribution inputs (parameters), a dist-class to
                generate an action distribution object from, and
                internal-state outputs (or an empty list if not applicable).
                Note: No Exploration hooks have to be called from within
                `action_distribution_fn`. It should only perform a simple
                forward pass through some model.
                If None, pass inputs through `self.model()` to get distribution
                inputs.
                The callable takes as inputs: Policy, ModelV2, ModelInputDict,
                explore, timestep, is_training.
            max_seq_len (int): Max sequence length for LSTM training.
            get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
                Optional callable that returns the divisibility requirement
                for sample batches given the Policy.
        """
        self.framework = "torch"
        super().__init__(observation_space, action_space, config)

        # Create multi-GPU model towers, if necessary.
        # - The central main model will be stored under self.model, residing on
        #   self.device.
        # - Each GPU will have a copy of that model under
        #   self.model_gpu_towers, matching the devices in self.devices.
        # - Parallelization is done by splitting the train batch and passing
        #   it through the model copies in parallel, then averaging over the
        #   resulting gradients, applying these averages on the main model and
        #   updating all towers' weights from the main model.
        # - In case of just one device (1 (fake) GPU or 1 CPU), no
        #   parallelization will be done.
        if config["_fake_gpus"] or config["num_gpus"] == 0 or \
                not torch.cuda.is_available():
            logger.info(
                "TorchPolicy running on {}.".format("{} fake-GPUs".format(
                    config["num_gpus"]) if config["_fake_gpus"] else "CPU"))
            self.device = torch.device("cpu")
            self.devices = [
                self.device for _ in range(config["num_gpus"] or 1)
            ]
            self.model_gpu_towers = [
                model if config["num_gpus"] == 0 else copy.deepcopy(model)
                for i in range(config["num_gpus"] or 1)
            ]
        else:
            logger.info("TorchPolicy running on {} GPU(s).".format(
                config["num_gpus"]))
            self.device = torch.device("cuda")
            self.devices = [
                torch.device("cuda:{}".format(id_))
                for i, id_ in enumerate(ray.get_gpu_ids())
                if i < config["num_gpus"]
            ]
            self.model_gpu_towers = nn.parallel.replicate(
                model, [
                    id_ for i, id_ in enumerate(ray.get_gpu_ids())
                    if i < config["num_gpus"]
                ])
        self.model = model.to(self.device)

        # Lock used for locking some methods on the object-level.
        # This prevents possible race conditions when calling the model
        # first, then its value function (e.g. in a loss function), in
        # between of which another model call is made (e.g. to compute an
        # action).
        self._lock = threading.RLock()

        self._state_inputs = self.model.get_initial_state()
        self._is_recurrent = len(self._state_inputs) > 0
        # Auto-update model's inference view requirements, if recurrent.
        self._update_model_view_requirements_from_init_state()
        # Combine view_requirements for Model and Policy.
        self.view_requirements.update(self.model.view_requirements)

        self.exploration = self._create_exploration()
        self.unwrapped_model = model  # used to support DistributedDataParallel
        self._loss = loss
        self._optimizers = force_list(self.optimizer())
        # Store, which params (by index within the model's list of
        # parameters) should be updated per optimizer.
        # Maps optimizer idx to a set of param indices.
        self.multi_gpu_param_groups: List[Set[int]] = []
        main_params = {p: i for i, p in enumerate(self.model.parameters())}
        for o in self._optimizers:
            param_indices = []
            for pg_idx, pg in enumerate(o.param_groups):
                for p in pg["params"]:
                    param_indices.append(main_params[p])
            self.multi_gpu_param_groups.append(set(param_indices))

        self.dist_class = action_distribution_class
        self.action_sampler_fn = action_sampler_fn
        self.action_distribution_fn = action_distribution_fn

        # If set, means we are using distributed allreduce during learning.
        self.distributed_world_size = None

        self.max_seq_len = max_seq_len
        self.batch_divisibility_req = get_batch_divisibility_req(self) if \
            callable(get_batch_divisibility_req) else \
            (get_batch_divisibility_req or 1)
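
A tiny standalone sketch of how the constructor above resolves `batch_divisibility_req` from `get_batch_divisibility_req`: call it if it is a callable, otherwise fall back to the given int or to 1. The helper name is hypothetical and used only for illustration.

def resolve_batch_divisibility_req(policy, get_batch_divisibility_req=None):
    # Callable: ask it for the requirement, given the policy.
    if callable(get_batch_divisibility_req):
        return get_batch_divisibility_req(policy)
    # Plain int (or None): use it directly, defaulting to 1.
    return get_batch_divisibility_req or 1

assert resolve_batch_divisibility_req(None) == 1
assert resolve_batch_divisibility_req(None, 8) == 8
assert resolve_batch_divisibility_req(None, lambda p: 4) == 4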