Example #1
    def save(self, model_dir: str, network_map: Dict[str, str] = None,
             version: int = 0):
        """
        Save models.

        An example of network map::

            {"restorable_model_1": "file_name_1",
             "restorable_model_2": "file_name_2"}

        Get keys by calling ``<Class name>.get_restorable()``

        Args:
            model_dir: Save directory.
            network_map: Key is module name, value is saved name.
            version: Version number of the new save.
        """
        network_map = {} if network_map is None else network_map
        if version == -1:
            version = "default"
            default_logger.warning(
                "You are using the default version to save, "
                "use custom version instead.")
        for r in self._is_restorable:
            if r in network_map:
                t.save(getattr(self, r),
                       join(model_dir,
                            "{}_{}.pt".format(network_map[r], version)))
            else:
                default_logger.warning("Save name for module \"{}\" is not "
                                       "specified, module name is used."
                                       .format(r))
                t.save(getattr(self, r),
                       join(model_dir,
                            "{}_{}.pt".format(r, version)))
Example #2
File: base.py Project: ikamensh/machin
    def load(self,
             model_dir: str,
             network_map: Dict[str, str] = None,
             version: int = -1):
        """
        Load models.

        An example of network map::

            {"restorable_model_1": "file_name_1",
             "restorable_model_2": "file_name_2"}

        Get keys by calling ``<Class name>.get_restorable()``

        Args:
            model_dir: Save directory.
            network_map: Key is module name, value is saved name.
            version: Version number of the save to be loaded.
        """
        network_map = {} if network_map is None else network_map
        restore_map = {}
        for r in self._is_restorable:
            if r in network_map:
                restore_map[network_map[r]] = getattr(self, r)
            else:
                default_logger.warning(
                    'Load path for module "{}" is not specified, '
                    "module name is used.".format(r))
                restore_map[r] = getattr(self, r)
        prep_load_model(model_dir, restore_map, version)
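
A minimal usage sketch of the save/load pair above; the ``agent`` object and
its restorable module names below are hypothetical, not part of machin:

# Hypothetical usage; `agent` is assumed to expose the save()/load() methods
# above and to report {"qnet", "qnet_target"} via get_restorable().
agent.save(
    "./checkpoints",
    network_map={"qnet": "q_network", "qnet_target": "q_network_target"},
    version=3,
)
# Expected files: ./checkpoints/q_network_3.pt and q_network_target_3.pt
agent.load(
    "./checkpoints",
    network_map={"qnet": "q_network", "qnet_target": "q_network_target"},
    version=3,
)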
Example #3
File: dqn_per.py Project: ikamensh/machin
 def __init__(self,
              qnet: Union[NeuralNetworkModule, nn.Module],
              qnet_target: Union[NeuralNetworkModule, nn.Module],
              optimizer: Callable,
              criterion: Callable,
              *_,
              lr_scheduler: Callable = None,
              lr_scheduler_args: Tuple[Tuple] = None,
              lr_scheduler_kwargs: Tuple[Dict] = None,
              batch_size: int = 100,
              epsilon_decay: float = 0.9999,
              update_rate: Union[float, None] = 0.005,
              update_steps: Union[int, None] = None,
              learning_rate: float = 0.001,
              discount: float = 0.99,
              gradient_max: float = np.inf,
              replay_size: int = 500000,
              replay_device: Union[str, t.device] = "cpu",
              replay_buffer: Buffer = None,
              visualize: bool = False,
              visualize_dir: str = "",
              **__):
     # DOC INHERITED
     super().__init__(
         qnet,
         qnet_target,
         optimizer,
         criterion,
         lr_scheduler=lr_scheduler,
         lr_scheduler_args=lr_scheduler_args,
         lr_scheduler_kwargs=lr_scheduler_kwargs,
         batch_size=batch_size,
         epsilon_decay=epsilon_decay,
         update_rate=update_rate,
         update_steps=update_steps,
         learning_rate=learning_rate,
         discount=discount,
         gradient_max=gradient_max,
         replay_size=replay_size,
         replay_device=replay_device,
         replay_buffer=(PrioritizedBuffer(replay_size, replay_device)
                        if replay_buffer is None else replay_buffer),
         mode="double",
         visualize=visualize,
         visualize_dir=visualize_dir,
     )
     # criterion reduction must be 'none'
     if not hasattr(self.criterion, "reduction"):
         raise RuntimeError("Criterion does not have the 'reduction' "
                            "property. Are you using a custom criterion?")
     else:
         # A loss defined in ``torch.nn.modules.loss``
         if self.criterion.reduction != "none":
             default_logger.warning(
                 "The reduction property of criterion is not 'none', "
                 "automatically corrected.")
             self.criterion.reduction = "none"
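
The reduction check above exists because prioritized replay weights the TD
error of each sample individually. A minimal sketch of why an element-wise
criterion is required (the tensors and weights are illustrative, not machin
code):

import torch as t
import torch.nn as nn

# With reduction="none" the criterion returns one loss per sample, which can
# then be scaled by per-sample importance-sampling weights before averaging,
# as prioritized experience replay requires.
criterion = nn.MSELoss(reduction="none")
q_values = t.randn(4, 1)
targets = t.randn(4, 1)
is_weights = t.tensor([[0.9], [1.0], [0.7], [1.1]])  # illustrative weights
weighted_loss = (criterion(q_values, targets) * is_weights).mean()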
Example #4
 def __init__(self,
              actor: Union[NeuralNetworkModule, nn.Module],
              actor_target: Union[NeuralNetworkModule, nn.Module],
              critic: Union[NeuralNetworkModule, nn.Module],
              critic_target: Union[NeuralNetworkModule, nn.Module],
              optimizer: Callable,
              criterion,
              *_,
              lr_scheduler: Callable = None,
              lr_scheduler_args: Tuple[Tuple, Tuple] = None,
              lr_scheduler_kwargs: Tuple[Dict, Dict] = None,
              batch_size: int = 100,
              update_rate: float = 0.005,
              actor_learning_rate: float = 0.0005,
              critic_learning_rate: float = 0.001,
              discount: float = 0.99,
              gradient_max: float = np.inf,
              replay_size: int = 500000,
              replay_device: Union[str, t.device] = "cpu",
              replay_buffer: Buffer = None,
              visualize: bool = False,
              visualize_dir: str = "",
              **__):
     # DOC INHERITED
     super(DDPGPer, self).__init__(
         actor, actor_target, critic, critic_target, optimizer, criterion,
         lr_scheduler=lr_scheduler,
         lr_scheduler_args=lr_scheduler_args,
         lr_scheduler_kwargs=lr_scheduler_kwargs,
         batch_size=batch_size,
         update_rate=update_rate,
         actor_learning_rate=actor_learning_rate,
         critic_learning_rate=critic_learning_rate,
         discount=discount,
         gradient_max=gradient_max,
         replay_size=replay_size,
         replay_device=replay_device,
         replay_buffer=(PrioritizedBuffer(replay_size, replay_device)
                        if replay_buffer is None
                        else replay_buffer),
         visualize=visualize,
         visualize_dir=visualize_dir
     )
     # criterion reduction must be 'none'
     if not hasattr(self.criterion, "reduction"):
         raise RuntimeError("Criterion does not have the "
                            "'reduction' property")
     else:
         # A loss defined in ``torch.nn.modules.loss``
         if self.criterion.reduction != "none":
             default_logger.warning(
                 "The reduction property of criterion is not 'none', "
                 "automatically corrected."
             )
             self.criterion.reduction = "none"
Example #5
    def _jit_safe_call(model, *named_args):
        if (not hasattr(model, "input_device")
                or not hasattr(model, "output_device")):
            # try to automatically determine the input & output
            # device of the model
            model_type = type(model)
            device = determine_device(model)
            if len(device) > 1:
                raise RuntimeError(
                    "Failed to automatically determine i/o device "
                    "of your model: {}\n"
                    "Detected multiple devices: {}\n"
                    "You need to manually specify i/o device of "
                    "your model.\n"
                    "Wrap your model of type nn.Module with one "
                    "of: \n"
                    "1. static_module_wrapper "
                    "from machin.model.nets.base \n"
                    "2. dynamic_module_wrapper "
                    "from machin.model.nets.base \n"
                    "Or construct your own module & model with: \n"
                    "NeuralNetworkModule from machin.model.nets.base".format(
                        model_type, device))
            else:
                # assume that i/o devices are the same as parameter device
                # print a warning
                default_logger.warning(
                    "You have not specified the i/o device of "
                    "your model {}, automatically determined and"
                    " set to: {}\n"
                    "The framework is not responsible for any "
                    "un-matching device issues caused by this "
                    "operation.".format(model_type, device[0]))
                model = static_module_wrapper(model, device[0], device[0])
        input_device = model.input_device
        # set in __init__
        args = model.arg_spec.args[1:] + model.arg_spec.kwonlyargs
        if model.arg_spec.defaults is not None:
            args_with_defaults = args[-len(model.arg_spec.defaults):]
        else:
            args_with_defaults = []
        required_args = (set(args) - set(args_with_defaults) -
                         set(model.arg_spec.kwonlydefaults.keys() if model.
                             arg_spec.kwonlydefaults is not None else []))
        model_type = model.model_type
        # t.jit._fork does not support keyword args
        # fill arguments in by their positions.
        args_list = [None for _ in args]
        args_filled = [False for _ in args]

        for na in named_args:
            for k, v in na.items():
                if k in args:
                    args_filled[args.index(k)] = True
                    if t.is_tensor(v):
                        args_list[args.index(k)] = v.to(input_device)
                    else:
                        args_list[args.index(k)] = v

        if not all(args_filled):
            not_filled = [
                arg for filled, arg in zip(args_filled, args) if not filled
            ]
            required_not_filled = set(not_filled).intersection(required_args)
            if len(required_not_filled) > 0:
                raise RuntimeError("\n"
                                   "The signature of the forward function "
                                   "of Model {} is {}\n"
                                   "Missing required arguments: {}, "
                                   "check your storage functions.".format(
                                       model_type, required_args,
                                       required_not_filled))

        return t.jit._fork(model, *args_list)
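
``t.jit._fork`` is the internal counterpart of the public ``torch.jit.fork``
API. A small sketch of forking a module call asynchronously and collecting
its result (the module and shapes are illustrative):

import torch as t
import torch.nn as nn

# fork launches the call asynchronously and returns a future;
# wait blocks until the forked call has finished and returns its result.
model = nn.Linear(4, 2)
future = t.jit.fork(model, t.randn(1, 4))
output = t.jit.wait(future)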
Example #6
    def update(self,
               update_value=True,
               update_policy=True,
               concatenate_samples=True,
               **__):
        # DOC INHERITED
        sum_value_loss = 0

        self.actor.train()
        self.critic.train()

        # sample a batch for actor training
        batch_size, (state, action,
                     advantage) = self.replay_buffer.sample_batch(
                         -1,
                         sample_method="all",
                         concatenate=concatenate_samples,
                         sample_attrs=["state", "action", "gae"],
                         additional_concat_attrs=["gae"],
                     )

        # normalize advantage
        if self.normalize_advantage:
            advantage = (advantage - advantage.mean()) / (advantage.std() +
                                                          1e-6)

        # Train actor
        # define two closures needed by fvp functions
        ___, fixed_action_log_prob, *_ = self._eval_act(state, action)
        fixed_action_log_prob = fixed_action_log_prob.view(batch_size,
                                                           1).detach()
        fixed_params = self.get_flat_params(self.actor)

        def actor_loss_func():
            ____, action_log_prob, *_ = self._eval_act(state, action)
            action_log_prob = action_log_prob.view(batch_size, 1)
            action_loss = -advantage.to(action_log_prob.device) * t.exp(
                action_log_prob - fixed_action_log_prob)
            return action_loss.mean()

        def actor_kl_func():
            state["params"] = fixed_params
            return safe_return(
                safe_call(self.actor, state, method="compare_kl"))

        act_policy_loss = actor_loss_func()

        if self.visualize:
            self.visualize_model(act_policy_loss, "actor", self.visualize_dir)

        # Update actor network
        if update_policy:

            def fvp(v):
                if self.hv_mode == "fim":
                    return self._fvp_fim(state, v, self.damping)
                else:
                    return self._fvp_direct(state, v, self.damping)

            loss_grad = self.get_flat_grad(
                act_policy_loss, list(self.actor.parameters())).detach()

            # usually 1e-15 is low enough
            if t.allclose(loss_grad, t.zeros_like(loss_grad), atol=1e-15):
                default_logger.warning("TRPO detects zero gradient.")

            step_dir = self._conjugate_gradients(
                fvp,
                -loss_grad,
                eps=self.conjugate_eps,
                iterations=self.conjugate_iterations,
                res_threshold=self.conjugate_res_threshold,
            )

            # Maximum step size mentioned in appendix C of the paper.
            beta = np.sqrt(2 * self.kl_max_delta /
                           step_dir.dot(fvp(step_dir)).item())

            full_step = step_dir * beta
            if not self._line_search(self.actor, actor_loss_func,
                                     actor_kl_func, full_step,
                                     self.kl_max_delta):
                default_logger.warning(
                    "Cannot find an update step to satisfy kl_max_delta, "
                    "consider increasing line_search_backtracks")

        for _ in range(self.critic_update_times):
            # sample a batch
            batch_size, (state,
                         target_value) = self.replay_buffer.sample_batch(
                             self.batch_size,
                             sample_method="random_unique",
                             concatenate=concatenate_samples,
                             sample_attrs=["state", "value"],
                             additional_concat_attrs=["value"],
                         )
            # calculate value loss
            value = self._criticize(state)
            value_loss = (self.criterion(target_value.type_as(value), value) *
                          self.value_weight)

            if self.visualize:
                self.visualize_model(value_loss, "critic", self.visualize_dir)

            # Update critic network
            if update_value:
                self.critic.zero_grad()
                self._backward(value_loss)
                nn.utils.clip_grad_norm_(self.critic.parameters(),
                                         self.gradient_max)
                self.critic_optim.step()
            sum_value_loss += value_loss.item()

        self.replay_buffer.clear()
        self.actor.eval()
        self.critic.eval()
        return (
            act_policy_loss,
            sum_value_loss / self.critic_update_times,
        )
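
The actor step above solves ``H x = -g`` with conjugate gradients, where the
Fisher matrix ``H`` is only available through the Fisher-vector product
``fvp``. A generic sketch of such a solver (this is not machin's
``_conjugate_gradients``; the signature is an assumption):

import torch as t

def conjugate_gradients(fvp, b, iterations=10, res_threshold=1e-10, eps=1e-8):
    # Solve A x = b given only the matrix-vector product fvp(v) = A v.
    x = t.zeros_like(b)
    r = b.clone()  # residual b - A x, with x = 0
    p = b.clone()  # search direction
    r_dot = r.dot(r)
    for _ in range(iterations):
        Ap = fvp(p)
        alpha = r_dot / (p.dot(Ap) + eps)
        x = x + alpha * p
        r = r - alpha * Ap
        new_r_dot = r.dot(r)
        if new_r_dot < res_threshold:
            break
        p = r + (new_r_dot / r_dot) * p
        r_dot = new_r_dot
    return x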
Example #7
def safe_call(model, *named_args):
    """
    Call a model and discard unnecessary arguments. ``safe_call`` will
    automatically move tensors in ``named_args`` to the input device of
    the model.

    Any input tensor in ``named_args`` must not be contained inside a
    container such as a list, dict or tuple, because only tensors passed
    directly as values are moved to the input device of the specified
    model.

    Args:
        model: Model to be called, must be a wrapped nn.Module or an
            instance of NeuralNetworkModule.
        named_args: Dictionaries of arguments; each key is an argument's
            name and each value is that argument's value.

    Returns:
        Whatever is returned by your module. If the result is not a tuple,
        it is wrapped inside a tuple.
    """
    org_model = None
    if isinstance(
            model,
        (nn.parallel.DistributedDataParallel, nn.parallel.DataParallel)):
        org_model = model
        model = model.module
    if (not hasattr(model, "input_device")
            or not hasattr(model, "output_device")):
        # try to automatically determine the input & output device of the model
        model_type = type(model)
        device = determine_device(model)
        if len(device) > 1:
            raise RuntimeError(
                "Failed to automatically determine i/o device "
                "of your model: {}\n"
                "Detected multiple devices: {}\n"
                "You need to manually specify i/o device of "
                "your model.\n"
                "Wrap your model of type nn.Module with one "
                "of: \n"
                "1. static_module_wrapper "
                "from machin.model.nets.base \n"
                "2. dynamic_module_wrapper "
                "from machin.model.nets.base \n"
                "Or construct your own module & model with: \n"
                "NeuralNetworkModule from machin.model.nets.base".format(
                    model_type, device))
        else:
            # assume that i/o devices are the same as parameter device
            # print a warning
            default_logger.warning(
                "You have not specified the i/o device of "
                "your model {}, automatically determined and"
                " set to: {}\n"
                "The framework is not responsible for any "
                "un-matching device issues caused by this "
                "operation.".format(model_type, device[0]))
            model = static_module_wrapper(model, device[0], device[0])

    input_device = model.input_device
    arg_spec = inspect.getfullargspec(model.forward)
    # exclude self in arg_spec.args
    args = arg_spec.args[1:] + arg_spec.kwonlyargs
    if arg_spec.defaults is not None:
        args_with_defaults = args[-len(arg_spec.defaults):]
    else:
        args_with_defaults = []
    required_args = (set(args) - set(args_with_defaults) -
                     set(arg_spec.kwonlydefaults.keys() if arg_spec.
                         kwonlydefaults is not None else []))
    args_dict = {}

    # fill in args
    for na in named_args:
        for k, v in na.items():
            if k in args:
                if t.is_tensor(v):
                    args_dict[k] = v.to(input_device)
                else:
                    args_dict[k] = v

    # check for necessary args
    missing = required_args - set(args_dict.keys())
    if len(missing) > 0:
        raise RuntimeError("\n"
                           "The signature of the forward function of Model {} "
                           "is {}\n"
                           "Missing required arguments: {}, "
                           "check your storage functions.".format(
                               type(model), required_args, missing))

    if org_model is not None:
        result = org_model(**args_dict)
    else:
        result = model(**args_dict)

    if isinstance(result, tuple):
        return result
    else:
        return (result, )
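
A minimal usage sketch (the module below is illustrative): wrap a plain
``nn.Module`` with ``static_module_wrapper`` so its i/o devices are known,
then call it through ``safe_call``; unknown keys are silently discarded.

import torch as t
import torch.nn as nn
from machin.model.nets.base import static_module_wrapper

class QNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, state):
        return self.fc(state)

# Declare input/output devices explicitly so safe_call does not have to
# guess them from the parameters.
model = static_module_wrapper(QNet(), "cpu", "cpu")
# "unused" is not in forward's signature, so safe_call (defined above)
# discards it.
(q_values,) = safe_call(model, {"state": t.randn(1, 4), "unused": 0})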
Example #8
File: ars.py Project: ikamensh/machin
    def __init__(
        self,
        actor: Union[NeuralNetworkModule, nn.Module],
        optimizer: Callable,
        ars_group: RpcGroup,
        model_server: Tuple[PushPullModelServer],
        *_,
        lr_scheduler: Callable = None,
        lr_scheduler_args: Tuple[Tuple] = None,
        lr_scheduler_kwargs: Tuple[Dict] = None,
        learning_rate: float = 0.01,
        gradient_max: float = np.inf,
        noise_std_dev: float = 0.02,
        noise_size: int = 250000000,
        rollout_num: int = 32,
        used_rollout_num: int = 32,
        normalize_state: bool = True,
        noise_seed: int = 12345,
        sample_seed: int = 123,
        **__,
    ):
        """

        Note:
            The first process in `ars_group` will be the manager process.

        Args:
            actor: Actor network module.
            optimizer: Optimizer used to optimize ``actor``.
            ars_group: Group of all processes using the ARS framework.
            model_server: Custom model sync server accessor for ``actor``.
            lr_scheduler: Learning rate scheduler of ``optimizer``.
            lr_scheduler_args: Arguments of the learning rate scheduler.
            lr_scheduler_kwargs: Keyword arguments of the learning
                rate scheduler.
            learning_rate: Learning rate of the optimizer, not compatible with
                ``lr_scheduler``.
            gradient_max: Maximum gradient.
            noise_std_dev: Standard deviation of the shared noise array.
            noise_size: Size of the shared noise array.
            rollout_num: Number of rollouts executed by workers in group.
            used_rollout_num: Number of used rollouts.
            normalize_state: Whether to normalize the state seen by the actor.
            noise_seed: Random seed used to generate noise.
            sample_seed: Base random seed used to sample noise.
        """
        assert rollout_num >= used_rollout_num
        self.grad_max = gradient_max
        self.rollout_num = rollout_num
        self.used_rollout_num = used_rollout_num
        self.normalize_state = normalize_state
        self.ars_group = ars_group

        # determine the number of rollouts (pairs of actors with neg/pos
        # delta) assigned to the current worker process
        w_num = len(ars_group.get_group_members())
        w_index = ars_group.get_group_members().index(ars_group.get_cur_name())
        segment_length = int(np.ceil(rollout_num / w_num))
        self.local_rollout_min = w_index * segment_length
        self.local_rollout_num = min(
            segment_length, rollout_num - self.local_rollout_min
        )

        self.actor = actor
        # `actor_with_delta` uses (rollout index, delta sign) as key,
        # where the rollout index is the absolute global index of the rollout
        # and the delta sign is True for positive, False for negative.
        self.actor_with_delta = {}  # type: Dict[Tuple[int, bool], t.nn.Module]
        self.actor_optim = optimizer(self.actor.parameters(), lr=learning_rate)
        self.actor_model_server = model_server[0]

        # `filter` uses the state name as key,
        # e.g. "state_1"
        self.filter = {}  # type: Dict[str, MeanStdFilter]

        # `delta_idx` uses the rollout index as key.
        # The inner dict uses the model parameter name as key, and the
        # starting noise index in the noise array as value.
        self.delta_idx = {}  # type: Dict[int, Dict[str, int]]

        # `reward` uses the rollout index as key; the first list stores
        # rewards of the model with negative noise delta, the second list
        # stores rewards of the model with positive noise delta.
        self.reward = {}  # type: Dict[int, Tuple[List, List]]

        if lr_scheduler is not None:
            if lr_scheduler_args is None:
                lr_scheduler_args = ((),)
            if lr_scheduler_kwargs is None:
                lr_scheduler_kwargs = ({},)
            self.actor_lr_sch = lr_scheduler(
                self.actor_optim, *lr_scheduler_args[0], **lr_scheduler_kwargs[0],
            )

        # generate shared noise
        # estimate model parameter num first
        param_max_num = 0
        for param in actor.parameters():
            param_max_num = max(np.prod(np.array(param.shape)), param_max_num)
        if param_max_num >= noise_size:
            raise ValueError(
                "Noise size {} is too small compared to "
                "maximum parameter size {}!".format(noise_size, param_max_num)
            )
        elif param_max_num * 10 > noise_size:
            default_logger.warning(
                "Maximum parameter size of your model is "
                "{}, which is more than 1/10 of your noise "
                "size {}, consider increasing noise_size.".format(
                    param_max_num, noise_size
                )
            )

        # create shared noise array
        self.noise_array = t.tensor(
            np.random.RandomState(noise_seed).randn(noise_size).astype(np.float64)
            * noise_std_dev
        )

        # create a sampler for each parameter in each rollout model;
        # the outer key is the rollout index, the inner key is the
        # parameter name
        self.noise_sampler = {}  # type: Dict[int, Dict[str, SharedNoiseSampler]]
        param_num = len(list(actor.parameters()))
        for lrn in range(self.local_rollout_num):
            r_idx = lrn + self.local_rollout_min
            sampler = {}
            for p_idx, (name, param) in enumerate(actor.named_parameters()):
                # each model and its inner parameters use a different
                # sampling stream of the same noise array.
                sampler[name] = SharedNoiseSampler(
                    self.noise_array, sample_seed + r_idx * param_num + p_idx
                )
            self.noise_sampler[r_idx] = sampler

        # synchronize base actor parameters
        self._sync_actor()
        self._generate_parameter()
        self._reset_reward_dict()
        super().__init__()
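
A worked sketch (numbers are illustrative) of the rollout partitioning
computed in ``__init__`` above: each worker owns a contiguous segment of
rollout indices, and the last worker gets the remainder.

import numpy as np

# Mirrors the segment computation above: 32 rollouts over 5 workers.
rollout_num, w_num = 32, 5
segment_length = int(np.ceil(rollout_num / w_num))  # 7
for w_index in range(w_num):
    local_rollout_min = w_index * segment_length
    local_rollout_num = min(segment_length, rollout_num - local_rollout_min)
    print(w_index, local_rollout_min, local_rollout_num)
# Workers 0-3 each own 7 rollouts; worker 4 owns the last 4 (indices 28-31).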
Example #9
File: td3.py Project: mrshenli/machin
    def __init__(self,
                 actor: Union[NeuralNetworkModule, nn.Module],
                 actor_target: Union[NeuralNetworkModule, nn.Module],
                 critic: Union[NeuralNetworkModule, nn.Module],
                 critic_target: Union[NeuralNetworkModule, nn.Module],
                 critic2: Union[NeuralNetworkModule, nn.Module],
                 critic2_target: Union[NeuralNetworkModule, nn.Module],
                 optimizer: Callable,
                 criterion: Callable,
                 *_,
                 lr_scheduler: Callable = None,
                 lr_scheduler_args: Tuple[Tuple, Tuple, Tuple] = None,
                 lr_scheduler_kwargs: Tuple[Dict, Dict, Dict] = None,
                 batch_size: int = 100,
                 update_rate: float = 0.005,
                 actor_learning_rate: float = 0.0005,
                 critic_learning_rate: float = 0.001,
                 discount: float = 0.99,
                 gradient_max: float = np.inf,
                 replay_size: int = 500000,
                 replay_device: Union[str, t.device] = "cpu",
                 replay_buffer: Buffer = None,
                 policy_noise_func: Callable = None,
                 reward_func: Callable = None,
                 action_trans_func: Callable = None,
                 visualize: bool = False,
                 visualize_dir: str = "",
                 **__):
        """
        See Also:
            :class:`.DDPG`

        Args:
            actor: Actor network module.
            actor_target: Target actor network module.
            critic: Critic network module.
            critic_target: Target critic network module.
            critic2: The second critic network module.
            critic2_target: The second target critic network module.
            optimizer: Optimizer used to optimize ``actor``, ``critic``
                and ``critic2``.
            criterion: Criterion used to evaluate the value loss.
            lr_scheduler: Learning rate scheduler of ``optimizer``.
            lr_scheduler_args: Arguments of the learning rate scheduler.
            lr_scheduler_kwargs: Keyword arguments of the learning
                rate scheduler.
            batch_size: Batch size used during training.
            update_rate: :math:`\\tau` used to update target networks.
                Target parameters are updated as:

                :math:`\\theta_t = \\theta * \\tau + \\theta_t * (1 - \\tau)`
            actor_learning_rate: Learning rate of the actor optimizer,
                not compatible with ``lr_scheduler``.
            critic_learning_rate: Learning rate of the critic optimizer,
                not compatible with ``lr_scheduler``.
            discount: :math:`\\gamma` used in the Bellman function.
            gradient_max: Maximum gradient.
            replay_size: Replay buffer size. Not compatible with
                ``replay_buffer``.
            replay_device: Device where the replay buffer is located. Not
                compatible with ``replay_buffer``.
            replay_buffer: Custom replay buffer.
            policy_noise_func: Function used to generate policy noise during
                update; if ``None``, no policy noise is applied.
            reward_func: Reward function used in training.
            action_trans_func: Action transform function, used to transform
                the raw output of your actor; by default it is:
                ``lambda act: {"action": act}``
            visualize: Whether to visualize the network flow in the first
                pass.
            visualize_dir: Visualized graph save directory.
        """
        if lr_scheduler_args is None:
            lr_scheduler_args = ((), (), ())
        if lr_scheduler_kwargs is None:
            lr_scheduler_kwargs = ({}, {}, {})

        super(TD3, self).__init__(
            actor,
            actor_target,
            critic,
            critic_target,
            optimizer,
            criterion,
            lr_scheduler=lr_scheduler,
            lr_scheduler_args=lr_scheduler_args[:2],
            lr_scheduler_kwargs=lr_scheduler_kwargs[:2],
            batch_size=batch_size,
            update_rate=update_rate,
            actor_learning_rate=actor_learning_rate,
            critic_learning_rate=critic_learning_rate,
            discount=discount,
            gradient_max=gradient_max,
            replay_size=replay_size,
            replay_device=replay_device,
            replay_buffer=replay_buffer,
            reward_func=reward_func,
            action_trans_func=action_trans_func,
            visualize=visualize,
            visualize_dir=visualize_dir)
        self.critic2 = critic2
        self.critic2_target = critic2_target
        self.critic2_optim = optimizer(self.critic2.parameters(),
                                       lr=critic_learning_rate)

        # Make sure target and online networks have the same weight
        with t.no_grad():
            hard_update(self.critic2, self.critic2_target)

        if lr_scheduler is not None:
            self.critic2_lr_sch = lr_scheduler(self.critic2_optim,
                                               *lr_scheduler_args[2],
                                               **lr_scheduler_kwargs[2])

        if policy_noise_func is None:
            default_logger.warning("Policy noise function is None, "
                                   "no policy noise will be applied "
                                   "during update!")
        self.policy_noise_func = (TD3._policy_noise_function
                                  if policy_noise_func is None else
                                  policy_noise_func)
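
The ``update_rate`` tau in the constructor drives a soft (Polyak) target
update, theta_t = theta * tau + theta_t * (1 - tau), as described in the
docstring above. A generic sketch of such an update (this is not machin's
internal implementation):

import torch as t

def soft_update(target_net: t.nn.Module, online_net: t.nn.Module,
                tau: float = 0.005):
    # Move every target parameter a small step toward its online counterpart.
    with t.no_grad():
        for target_param, online_param in zip(target_net.parameters(),
                                               online_net.parameters()):
            target_param.copy_(online_param * tau + target_param * (1 - tau))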