Example #1
class SARSALambdaContinuous(TD):
    """
    Continuous version of the SARSA(lambda) algorithm.

    """
    def __init__(self,
                 mdp_info,
                 policy,
                 approximator,
                 learning_rate,
                 lambda_coeff,
                 features,
                 approximator_params=None):
        """
        Constructor.

        Args:
            lambda_coeff (float): eligibility trace coefficient.

        """
        self._approximator_params = dict() if approximator_params is None else \
            approximator_params

        self.Q = Regressor(approximator, **self._approximator_params)
        self.e = np.zeros(self.Q.weights_size)
        self._lambda = lambda_coeff

        self._add_save_attr(_approximator_params='pickle',
                            Q='pickle',
                            _lambda='numpy',
                            e='numpy')

        super().__init__(mdp_info, policy, self.Q, learning_rate, features)

    def _update(self, state, action, reward, next_state, absorbing):
        phi_state = self.phi(state)
        q_current = self.Q.predict(phi_state, action)

        alpha = self.alpha(state, action)

        self.e = self.mdp_info.gamma * self._lambda * self.e + self.Q.diff(
            phi_state, action)

        self.next_action = self.draw_action(next_state)
        phi_next_state = self.phi(next_state)
        q_next = self.Q.predict(phi_next_state,
                                self.next_action) if not absorbing else 0.

        delta = reward + self.mdp_info.gamma * q_next - q_current

        theta = self.Q.get_weights()
        theta += alpha * delta * self.e
        self.Q.set_weights(theta)

    def episode_start(self):
        self.e = np.zeros(self.Q.weights_size)

        super().episode_start()
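
A minimal numpy sketch of the accumulating-trace update performed by _update above, assuming a linear Q (so that Q.diff returns the feature vector); the feature vectors and hyperparameters are made up purely for illustration:

import numpy as np

gamma, lam, alpha = 0.99, 0.9, 0.1     # discount, trace decay, learning rate
theta = np.zeros(4)                    # approximator weights
e = np.zeros(4)                        # eligibility trace over the weights

phi_sa = np.array([1., 0., 1., 0.])    # features of (state, action)
phi_next = np.array([0., 1., 0., 1.])  # features of (next_state, next_action)
reward = 1.

e = gamma * lam * e + phi_sa                                     # trace decays, then accumulates dQ/dtheta
delta = reward + gamma * (theta @ phi_next) - theta @ phi_sa     # TD error
theta += alpha * delta * e                                       # trace-weighted weight update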
Example #2
class DDPG(DeepAC):
    """
    Deep Deterministic Policy Gradient algorithm.
    "Continuous Control with Deep Reinforcement Learning".
    Lillicrap T. P. et al., 2016.

    """
    def __init__(self,
                 mdp_info,
                 policy_class,
                 policy_params,
                 actor_params,
                 actor_optimizer,
                 critic_params,
                 batch_size,
                 initial_replay_size,
                 max_replay_size,
                 tau,
                 policy_delay=1,
                 critic_fit_params=None):
        """
        Constructor.

        Args:
            policy_class (Policy): class of the policy;
            policy_params (dict): parameters of the policy to build;
            actor_params (dict): parameters of the actor approximator to
                build;
            actor_optimizer (dict): parameters to specify the actor optimizer
                algorithm;
            critic_params (dict): parameters of the critic approximator to
                build;
            batch_size (int): the number of samples in a batch;
            initial_replay_size (int): the number of samples to collect before
                starting the learning;
            max_replay_size (int): the maximum number of samples in the replay
                memory;
            tau (float): value of the coefficient for soft updates;
            policy_delay (int, 1): the number of critic updates after which an
                actor update is performed;
            critic_fit_params (dict, None): parameters of the fitting algorithm
                of the critic approximator.

        """
        self._critic_fit_params = dict() if critic_fit_params is None \
            else critic_fit_params

        self._batch_size = batch_size
        self._tau = tau
        self._policy_delay = policy_delay
        self._fit_count = 0

        self._replay_memory = ReplayMemory(initial_replay_size,
                                           max_replay_size)

        target_critic_params = deepcopy(critic_params)
        self._critic_approximator = Regressor(TorchApproximator,
                                              **critic_params)
        self._target_critic_approximator = Regressor(TorchApproximator,
                                                     **target_critic_params)

        target_actor_params = deepcopy(actor_params)
        self._actor_approximator = Regressor(TorchApproximator, **actor_params)
        self._target_actor_approximator = Regressor(TorchApproximator,
                                                    **target_actor_params)

        self._init_target(self._critic_approximator,
                          self._target_critic_approximator)
        self._init_target(self._actor_approximator,
                          self._target_actor_approximator)

        policy = policy_class(self._actor_approximator, **policy_params)

        policy_parameters = self._actor_approximator.model.network.parameters()

        super().__init__(mdp_info, policy, actor_optimizer, policy_parameters)

    def fit(self, dataset):
        self._replay_memory.add(dataset)
        if self._replay_memory.initialized:
            state, action, reward, next_state, absorbing, _ =\
                self._replay_memory.get(self._batch_size)

            q_next = self._next_q(next_state, absorbing)
            q = reward + self.mdp_info.gamma * q_next

            self._critic_approximator.fit(state, action, q,
                                          **self._critic_fit_params)

            if self._fit_count % self._policy_delay == 0:
                loss = self._loss(state)
                self._optimize_actor_parameters(loss)

            self._update_target(self._critic_approximator,
                                self._target_critic_approximator)
            self._update_target(self._actor_approximator,
                                self._target_actor_approximator)

            self._fit_count += 1

    def _loss(self, state):
        action = self._actor_approximator(state, output_tensor=True)
        q = self._critic_approximator(state, action, output_tensor=True)

        return -q.mean()

    def _next_q(self, next_state, absorbing):
        """
        Args:
            next_state (np.ndarray): the states where next action has to be
                evaluated;
            absorbing (np.ndarray): the absorbing flag for the states in
                ``next_state``.

        Returns:
            Action-values returned by the critic for ``next_state`` and the
            action returned by the actor.

        """
        a = self._target_actor_approximator(next_state)

        q = self._target_critic_approximator.predict(next_state, a)
        q *= 1 - absorbing

        return q
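
The target networks above track the online networks via Polyak averaging with coefficient tau. A stand-alone sketch of that soft update on plain torch modules; the _init_target/_update_target helpers of DeepAC are assumed to do the equivalent on the Regressor weights:

import torch
import torch.nn as nn

tau = 0.005
online = nn.Linear(4, 1)
target = nn.Linear(4, 1)
target.load_state_dict(online.state_dict())       # hard copy, as in _init_target

with torch.no_grad():                              # soft blend, as in _update_target
    for p, p_t in zip(online.parameters(), target.parameters()):
        p_t.mul_(1. - tau).add_(tau * p)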
Example #3
class SAC(DeepAC):
    """
    Soft Actor-Critic algorithm.
    "Soft Actor-Critic Algorithms and Applications".
    Haarnoja T. et al., 2019.

    """
    def __init__(self, mdp_info, actor_mu_params, actor_sigma_params,
                 actor_optimizer, critic_params, batch_size,
                 initial_replay_size, max_replay_size, warmup_transitions, tau,
                 lr_alpha, target_entropy=None, critic_fit_params=None):
        """
        Constructor.

        Args:
            actor_mu_params (dict): parameters of the actor mean approximator
                to build;
            actor_sigma_params (dict): parameters of the actor sigma
                approximator to build;
            actor_optimizer (dict): parameters to specify the actor
                optimizer algorithm;
            critic_params (dict): parameters of the critic approximator to
                build;
            batch_size (int): the number of samples in a batch;
            initial_replay_size (int): the number of samples to collect before
                starting the learning;
            max_replay_size (int): the maximum number of samples in the replay
                memory;
            warmup_transitions (int): number of samples to accumulate in the
                replay memory to start the policy fitting;
            tau (float): value of the coefficient for soft updates;
            lr_alpha (float): learning rate for the entropy coefficient;
            target_entropy (float, None): target entropy for the policy; if
                None, a default value is computed;
            critic_fit_params (dict, None): parameters of the fitting algorithm
                of the critic approximator.

        """
        self._critic_fit_params = dict() if critic_fit_params is None else critic_fit_params

        self._batch_size = batch_size
        self._warmup_transitions = warmup_transitions
        self._tau = tau

        if target_entropy is None:
            self._target_entropy = -np.prod(mdp_info.action_space.shape).astype(np.float32)
        else:
            self._target_entropy = target_entropy

        self._replay_memory = ReplayMemory(initial_replay_size, max_replay_size)

        if 'n_models' in critic_params.keys():
            assert critic_params['n_models'] == 2
        else:
            critic_params['n_models'] = 2

        target_critic_params = deepcopy(critic_params)
        self._critic_approximator = Regressor(TorchApproximator,
                                              **critic_params)
        self._target_critic_approximator = Regressor(TorchApproximator,
                                                     **target_critic_params)

        actor_mu_approximator = Regressor(TorchApproximator,
                                          **actor_mu_params)
        actor_sigma_approximator = Regressor(TorchApproximator,
                                             **actor_sigma_params)

        policy = SACPolicy(actor_mu_approximator,
                           actor_sigma_approximator,
                           mdp_info.action_space.low,
                           mdp_info.action_space.high)

        self._init_target(self._critic_approximator,
                          self._target_critic_approximator)

        self._log_alpha = torch.tensor(0., dtype=torch.float32)

        if policy.use_cuda:
            self._log_alpha = self._log_alpha.cuda().requires_grad_()
        else:
            self._log_alpha.requires_grad_()

        self._alpha_optim = optim.Adam([self._log_alpha], lr=lr_alpha)

        policy_parameters = chain(actor_mu_approximator.model.network.parameters(),
                                  actor_sigma_approximator.model.network.parameters())

        self._add_save_attr(
            _critic_fit_params='pickle',
            _batch_size='numpy',
            _warmup_transitions='numpy',
            _tau='numpy',
            _target_entropy='numpy',
            _replay_memory='pickle',
            _critic_approximator='pickle',
            _target_critic_approximator='pickle',
            _log_alpha='pickle',
            _alpha_optim='pickle'
        )

        super().__init__(mdp_info, policy, actor_optimizer, policy_parameters)

    def fit(self, dataset):
        self._replay_memory.add(dataset)
        if self._replay_memory.initialized:
            state, action, reward, next_state, absorbing, _ = \
                self._replay_memory.get(self._batch_size)

            if self._replay_memory.size > self._warmup_transitions:
                action_new, log_prob = self.policy.compute_action_and_log_prob_t(state)
                loss = self._loss(state, action_new, log_prob)
                self._optimize_actor_parameters(loss)
                self._update_alpha(log_prob.detach())

            q_next = self._next_q(next_state, absorbing)
            q = reward + self.mdp_info.gamma * q_next

            self._critic_approximator.fit(state, action, q,
                                          **self._critic_fit_params)

            self._update_target(self._critic_approximator,
                                self._target_critic_approximator)

    def _loss(self, state, action_new, log_prob):
        q_0 = self._critic_approximator(state, action_new,
                                        output_tensor=True, idx=0)
        q_1 = self._critic_approximator(state, action_new,
                                        output_tensor=True, idx=1)

        q = torch.min(q_0, q_1)

        return (self._alpha * log_prob - q).mean()

    def _update_alpha(self, log_prob):
        alpha_loss = - (self._log_alpha * (log_prob + self._target_entropy)).mean()
        self._alpha_optim.zero_grad()
        alpha_loss.backward()
        self._alpha_optim.step()

    def _next_q(self, next_state, absorbing):
        """
        Args:
            next_state (np.ndarray): the states where next action has to be
                evaluated;
            absorbing (np.ndarray): the absorbing flag for the states in
                ``next_state``.

        Returns:
            Action-values returned by the critic for ``next_state`` and the
            action returned by the actor.

        """
        a, log_prob_next = self.policy.compute_action_and_log_prob(next_state)

        q = self._target_critic_approximator.predict(
            next_state, a, prediction='min') - self._alpha_np * log_prob_next
        q *= 1 - absorbing

        return q

    def _post_load(self):
        if self._optimizer is not None:
            self._parameters = list(
                chain(self.policy._mu_approximator.model.network.parameters(),
                      self.policy._sigma_approximator.model.network.parameters()
                )
            )

    @property
    def _alpha(self):
        return self._log_alpha.exp()

    @property
    def _alpha_np(self):
        return self._alpha.detach().cpu().numpy()
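
A stand-alone sketch of the automatic temperature tuning carried out by _update_alpha, with random numbers standing in for the policy log-probabilities:

import torch
import torch.optim as optim

target_entropy = -2.                        # e.g. -dim(A) for a 2-D action space
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optim = optim.Adam([log_alpha], lr=3e-4)

log_prob = torch.randn(64)                  # placeholder for policy log-probs
alpha_loss = -(log_alpha * (log_prob + target_entropy)).mean()
alpha_optim.zero_grad()
alpha_loss.backward()
alpha_optim.step()
print('alpha:', log_alpha.exp().item())     # current entropy coefficient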
Example #4
import numpy as np
from matplotlib import pyplot as plt

from mushroom_rl.approximators import Regressor
from mushroom_rl.approximators.parametric import LinearApproximator

x = np.arange(10).reshape(-1, 1)

intercept = 10
noise = np.random.randn(10, 1) * 1
y = 2 * x + intercept + noise

phi = np.concatenate((np.ones(10).reshape(-1, 1), x), axis=1)

regressor = Regressor(LinearApproximator,
                      input_shape=(2, ),
                      output_shape=(1, ))

regressor.fit(phi, y)

print('Weights: ' + str(regressor.get_weights()))
print('Gradient: ' + str(regressor.diff(np.array([[5.]]))))

plt.scatter(x, y)
plt.plot(x, regressor.predict(phi))
plt.show()
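
A quick cross-check (an addition, not part of the original snippet): ordinary least squares on the same design matrix should recover roughly the same weights as the fitted LinearApproximator, about [10, 2] given the generating line above.

w, *_ = np.linalg.lstsq(phi, y, rcond=None)
print('Least-squares weights: ' + str(w.ravel()))   # approx. [10. 2.]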
Example #5
class DDPG(DeepAC):
    def __init__(self,
                 mdp_info,
                 policy_class,
                 policy_params,
                 actor_params,
                 actor_optimizer,
                 critic_params,
                 batch_size,
                 replay_memory,
                 tau,
                 optimization_steps,
                 comm,
                 policy_delay=1,
                 critic_fit_params=None):
        self._critic_fit_params = dict() if critic_fit_params is None \
            else critic_fit_params

        self._batch_size = batch_size
        self._tau = tau
        self._optimization_steps = optimization_steps
        self._comm = comm
        self._policy_delay = policy_delay
        self._fit_count = 0

        if comm.Get_rank() == 0:
            self._replay_memory = replay_memory

        target_critic_params = deepcopy(critic_params)
        self._critic_approximator = Regressor(TorchApproximator,
                                              **critic_params)
        self._target_critic_approximator = Regressor(TorchApproximator,
                                                     **target_critic_params)

        target_actor_params = deepcopy(actor_params)
        self._actor_approximator = Regressor(TorchApproximator, **actor_params)
        self._target_actor_approximator = Regressor(TorchApproximator,
                                                    **target_actor_params)

        self._init_target(self._critic_approximator,
                          self._target_critic_approximator)
        self._init_target(self._actor_approximator,
                          self._target_actor_approximator)

        policy = policy_class(self._actor_approximator, **policy_params)

        policy_parameters = self._actor_approximator.model.network.parameters()

        self._add_save_attr(_critic_fit_params='pickle',
                            _batch_size='numpy',
                            _tau='numpy',
                            _policy_delay='numpy',
                            _fit_count='numpy',
                            _replay_memory='pickle',
                            _critic_approximator='pickle',
                            _target_critic_approximator='pickle',
                            _actor_approximator='pickle',
                            _target_actor_approximator='pickle')

        super().__init__(mdp_info, policy, actor_optimizer, policy_parameters)

    def fit(self, dataset):
        if self._comm.Get_rank() == 0:
            for i in range(1, self._comm.Get_size()):
                dataset += self._comm.recv(source=i)
            self._replay_memory.add(dataset)

            self._comm.Barrier()
        else:
            self._comm.send(dataset, dest=0)
            self._comm.Barrier()

        for _ in range(self._optimization_steps):
            if self._comm.Get_rank() == 0:
                state, action, reward, next_state =\
                    self._replay_memory.get(self._batch_size * self._comm.Get_size())
            else:
                state = None
                action = None
                reward = None
                next_state = None
            state, action, reward, next_state = self._comm.bcast(
                [state, action, reward, next_state], root=0)

            start = self._batch_size * self._comm.Get_rank()
            stop = start + self._batch_size
            state = state[start:stop]
            action = action[start:stop]
            reward = reward[start:stop]
            next_state = next_state[start:stop]

            q_next = self._next_q(next_state)
            q = reward + self.mdp_info.gamma * q_next
            q = np.clip(q, -1 / (1 - self.mdp_info.gamma), 0)

            self._critic_approximator.fit(state, action, q,
                                          **self._critic_fit_params)

            if self._fit_count % self._policy_delay == 0:
                loss = self._loss(state)
                self._optimize_actor_parameters(loss)

            self._fit_count += 1

        self._update_target(self._critic_approximator,
                            self._target_critic_approximator)
        self._update_target(self._actor_approximator,
                            self._target_actor_approximator)

    def _loss(self, state):
        action = self._actor_approximator(state,
                                          output_tensor=True,
                                          scaled=False)
        q = self._critic_approximator(state, action, output_tensor=True)

        return -q.mean() + (action**2).mean()

    def _next_q(self, next_state):
        a = self._target_actor_approximator(next_state)

        q = self._target_critic_approximator.predict(next_state, a)

        return q

    def _post_load(self):
        if self._optimizer is not None:
            self._parameters = list(
                self._actor_approximator.model.network.parameters())

    def draw_action(self, state):
        state = np.append(state['observation'], state['desired_goal'])
        if self._comm.Get_rank() == 0:
            mu = self._replay_memory._mu
            sigma2 = self._replay_memory._sigma2
        else:
            mu = None
            sigma2 = None
        mu, sigma2 = self._comm.bcast([mu, sigma2], root=0)

        if not np.any(sigma2 == 0):
            state = normalize_and_clip(state, mu, sigma2)

        return self.policy.draw_action(state)
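
A minimal sketch of the gather-then-broadcast-then-slice pattern used in fit above, assuming comm is an mpi4py communicator (run with e.g. mpiexec -n 4 python sketch.py):

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()
batch_size = 8

if rank == 0:
    data = np.arange(batch_size * size)    # rank 0 owns the replay memory
else:
    data = None
data = comm.bcast(data, root=0)            # every rank receives the full batch

start = batch_size * rank                  # each rank optimizes on its own slice
print(rank, data[start:start + batch_size])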
Example #6
class OptionSAC(DeepOptionAC):

    def __init__(self, mdp_info, actor_mu_params, actor_sigma_params,
                 actor_optimizer, critic_params, batch_size,
                 initial_replay_size, max_replay_size, warmup_transitions, tau,
                 lr_alpha, rarhmm: rARHMM, target_entropy=None, critic_fit_params=None):

        """
        Constructor.

        Args:
            actor_mu_params (dict): parameters of the actor mean approximator
                to build;
            actor_sigma_params (dict): parameters of the actor sigma
                approximator to build;
            actor_optimizer (dict): parameters to specify the actor
                optimizer algorithm;
            critic_params (dict): parameters of the critic approximator to
                build;
            batch_size (int): the number of samples in a batch;
            initial_replay_size (int): the number of samples to collect before
                starting the learning;
            max_replay_size (int): the maximum number of samples in the replay
                memory;
            warmup_transitions (int): number of samples to accumulate in the
                replay memory to start the policy fitting;
            tau (float): value of the coefficient for soft updates;
            lr_alpha (float): learning rate for the entropy coefficient;
            rarhmm (rARHMM): model providing the options; its ``nb_states``
                is used as the number of options;
            target_entropy (float, None): target entropy for the policy; if
                None, a default value is computed;
            critic_fit_params (dict, None): parameters of the fitting algorithm
                of the critic approximator.

        """
        self.rarhmm = rarhmm
        self._critic_fit_params = dict() if critic_fit_params is None else critic_fit_params

        self._batch_size = batch_size
        self._warmup_transitions = warmup_transitions
        self._tau = tau

        if target_entropy is None:
            self._target_entropy = -np.prod(mdp_info.action_space.shape).astype(np.float32)
        else:
            self._target_entropy = target_entropy

        self._replay_memory = OptionReplayMemory(initial_replay_size, max_replay_size)

        if 'n_models' in critic_params.keys():
            assert critic_params['n_models'] == 2
        else:
            critic_params['n_models'] = 2

        target_critic_params = deepcopy(critic_params)
        self._critic_approximator = Regressor(TorchApproximator,
                                              **critic_params)
        self._target_critic_approximator = Regressor(TorchApproximator,
                                                     **target_critic_params)

        actor_mu_approximator = [Regressor(TorchApproximator,
                                          **actor_mu_params)
                                 for _ in range(rarhmm.nb_states)]
        actor_sigma_approximator = [Regressor(TorchApproximator,
                                             **actor_sigma_params)
                                    for _ in range(rarhmm.nb_states)]

        policy = [SACPolicy(actor_mu_approximator[o],
                           actor_sigma_approximator[o],
                           mdp_info.action_space.low,
                           mdp_info.action_space.high) for o in range(rarhmm.nb_states)]

        self._init_target(self._critic_approximator,
                          self._target_critic_approximator)

        self._log_alpha = torch.tensor(0., dtype=torch.float32)

        if policy[0].use_cuda:
            self._log_alpha = self._log_alpha.cuda().requires_grad_()
        else:
            self._log_alpha.requires_grad_()

        self._alpha_optim = optim.Adam([self._log_alpha], lr=lr_alpha)

        policy_parameters = [chain(actor_mu_approximator[o].model.network.parameters(),
                                  actor_sigma_approximator[o].model.network.parameters())
                             for o in range(rarhmm.nb_states)]

        self._add_save_attr(
            _critic_fit_params='pickle',
            _batch_size='numpy',
            _warmup_transitions='numpy',
            _tau='numpy',
            _target_entropy='numpy',
            _replay_memory='pickle',
            _critic_approximator='pickle',
            _target_critic_approximator='pickle',
            _log_alpha='pickle',
            _alpha_optim='pickle'
        )

        super().__init__(mdp_info, policy, actor_optimizer, policy_parameters, rarhmm.nb_states)

    def fit(self, dataset):
        self._replay_memory.add(dataset)
        if self._replay_memory.initialized:
            state, action, reward, next_state, absorbing, _, option, option_weight = \
                self._replay_memory.get(self._batch_size)

            if self._replay_memory.size > self._warmup_transitions:
                action_new = torch.empty((0, self.mdp_info.action_space.shape[0]))
                log_prob = torch.empty((0))
                for o in range(self.rarhmm.nb_states):
                    selection = (o == option)
                    _action_new, _log_prob = self.policy[o].compute_action_and_log_prob_t(state[selection])
                    action_new = torch.cat([action_new, _action_new]) # TODO: check if stacking is correct
                    log_prob = torch.cat([log_prob, _log_prob])
                    loss = self._loss(state[selection], _action_new, _log_prob)
                    self._optimize_actor_parameters(loss, o) # TODO: Look at global loss
                self._update_alpha(log_prob.detach())

            q_next = self._next_q(next_state, absorbing, option)
            q = reward + self.mdp_info.gamma * q_next

            self._critic_approximator.fit(state, action, q,
                                          **self._critic_fit_params)

            self._update_target(self._critic_approximator,
                                self._target_critic_approximator)

    def _loss(self, state, action_new, log_prob):
        q_0 = self._critic_approximator(state, action_new,
                                        output_tensor=True, idx=0)
        q_1 = self._critic_approximator(state, action_new,
                                        output_tensor=True, idx=1)

        q = torch.min(q_0, q_1)

        return (self._alpha * log_prob - q).mean()

    def _update_alpha(self, log_prob):
        alpha_loss = - (self._log_alpha * (log_prob + self._target_entropy)).mean()
        self._alpha_optim.zero_grad()
        alpha_loss.backward()
        self._alpha_optim.step()

    def _next_q(self, next_state, absorbing, option=None):
        """
        Args:
            next_state (np.ndarray): the states where next action has to be
                evaluated;
            absorbing (np.ndarray): the absorbing flag for the states in
                ``next_state``.

        Returns:
            Action-values returned by the critic for ``next_state`` and the
            action returned by the actor.

        """
        # TODO: submit options batch

        # a = np.empty((0, self.mdp_info.action_space.shape[0]))
        # log_prob_next = np.empty((0))
        # for o in range(self.n_options):
        #     selection = (o == option)
        #     _a, _log_prob_next = self.policy[o].compute_action_and_log_prob(next_state[selection])
        #     a = np.vstack([a, _a])
        #     log_prob_next = np.hstack([log_prob_next, _log_prob_next])

        a = 0
        log_prob_next = 0
        for o in range(self.n_options):
            _a, _log_prob_next = self.policy[o].compute_action_and_log_prob(next_state)
            a += _a
            log_prob_next += _log_prob_next
        a /= self.n_options
        log_prob_next /= self.n_options

        q = self._target_critic_approximator.predict(
            next_state, a, prediction='min') - self._alpha_np * log_prob_next
        q *= 1 - absorbing

        return q

    def _post_load(self):
        if self._optimizer is not None:
            # self.policy is a list with one SACPolicy per option, so the
            # trainable parameters are rebuilt per option, as in __init__.
            self._parameters = [
                list(chain(p._mu_approximator.model.network.parameters(),
                           p._sigma_approximator.model.network.parameters()))
                for p in self.policy
            ]

    @property
    def _alpha(self):
        return self._log_alpha.exp()

    @property
    def _alpha_np(self):
        return self._alpha.detach().cpu().numpy()
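
The TODO in fit flags that torch.cat stacks the per-option results in option order, so action_new and log_prob no longer line up row-by-row with state. A small index-preserving alternative (an assumption for illustration, not the author's fix), with random tensors standing in for the per-option policy output:

import torch

n_samples, n_options, act_dim = 6, 2, 3
option = torch.tensor([0, 1, 0, 1, 1, 0])
action_new = torch.empty(n_samples, act_dim)
log_prob = torch.empty(n_samples)

for o in range(n_options):
    sel = option == o                                        # samples assigned to option o
    action_new[sel] = torch.randn(int(sel.sum()), act_dim)   # stand-in for policy[o] output
    log_prob[sel] = torch.randn(int(sel.sum()))

print(action_new.shape, log_prob.shape)                      # rows stay aligned with the batch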
Example #7
class TrueOnlineSARSALambda(TD):
    """
    True Online SARSA(lambda) with linear function approximation.
    "True Online TD(lambda)". Seijen H. V. et al.. 2014.

    """
    def __init__(self, mdp_info, policy, learning_rate, lambda_coeff,
                 features, approximator_params=None):
        """
        Constructor.

        Args:
            lambda_coeff (float): eligibility trace coefficient.

        """
        self._approximator_params = dict() if approximator_params is None else \
            approximator_params

        self.Q = Regressor(LinearApproximator, **self._approximator_params)
        self.e = np.zeros(self.Q.weights_size)
        self._lambda = lambda_coeff
        self._q_old = None

        self._add_save_attr(
            _approximator_params='pickle',
            Q='pickle',
            _q_old='pickle',
            _lambda='numpy',
            e='numpy'
        )

        super().__init__(mdp_info, policy, self.Q, learning_rate, features)

    def _update(self, state, action, reward, next_state, absorbing):
        phi_state = self.phi(state)
        phi_state_action = get_action_features(phi_state, action,
                                               self.mdp_info.action_space.n)
        q_current = self.Q.predict(phi_state, action)

        if self._q_old is None:
            self._q_old = q_current

        alpha = self.alpha(state, action)

        e_phi = self.e.dot(phi_state_action)
        self.e = self.mdp_info.gamma * self._lambda * self.e + alpha * (
            1. - self.mdp_info.gamma * self._lambda * e_phi) * phi_state_action

        self.next_action = self.draw_action(next_state)
        phi_next_state = self.phi(next_state)
        q_next = self.Q.predict(phi_next_state,
                                self.next_action) if not absorbing else 0.

        delta = reward + self.mdp_info.gamma * q_next - self._q_old

        theta = self.Q.get_weights()
        theta += delta * self.e + alpha * (
            self._q_old - q_current) * phi_state_action
        self.Q.set_weights(theta)

        self._q_old = q_next

    def episode_start(self):
        self._q_old = None
        self.e = np.zeros(self.Q.weights_size)

        super().episode_start()
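
A plain numpy sketch of the dutch-trace update carried out by _update above (toy feature vectors, linear Q as in the class):

import numpy as np

gamma, lam, alpha = 0.99, 0.9, 0.1
theta = np.zeros(3)                     # weights
e = np.zeros(3)                         # dutch eligibility trace
q_old = 0.

phi_sa = np.array([1., 0., 1.])         # features of (state, action)
phi_next = np.array([0., 1., 0.])       # features of (next_state, next_action)
reward = 1.

q_current = theta @ phi_sa
e = gamma * lam * e + alpha * (1. - gamma * lam * e @ phi_sa) * phi_sa
q_next = theta @ phi_next
delta = reward + gamma * q_next - q_old
theta += delta * e + alpha * (q_old - q_current) * phi_sa
q_old = q_next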