Example #1
import numpy as np
import torch
from torch.distributions import Categorical

# `Actor` is the project's policy network (defined elsewhere in the repo).

class PolicyGradEnt():
    """
    Implements an RL agent trained with the policy gradient method and an entropy regularization term.
    
    Notes
    -----
    The GPU implementation is only sketched; it works, but it is slower than running on the CPU.
    """
    def __init__(self,
                 observation_space,
                 action_space,
                 lr,
                 gamma,
                 H,
                 discrete=True,
                 project_dim=4,
                 device='cpu'):
        """
        Parameters
        ----------
        observation_space: int
            Number of flattened entries of the state
        action_space: int
            Number of (discrete) possible actions to take
        lr: float in [0,1]
            Learning rate
        gamma: float in [0,1]
            Discount factor
        H: float
            Entropy bonus coefficient
        """

        self.gamma = gamma
        self.lr = lr
        self.H = H  # entropy coeff

        self.n_actions = action_space
        self.discrete = discrete
        if self.discrete:
            self.net = Actor(observation_space, action_space, discrete,
                             project_dim)
        else:
            self.net = Actor(observation_space, action_space, discrete)
        self.optim = torch.optim.Adam(self.net.parameters(), lr=self.lr)

        self.device = device
        self.net.to(self.device)  # move network to device

    def get_action(self, state, return_log=False):
        log_probs = self.forward(state)
        dist = torch.exp(log_probs)

        probs = Categorical(dist)
        action = probs.sample().item()

        if return_log:
            return action, log_probs.view(-1)[action], dist
        else:
            return action

    def forward(self, state):
        if self.discrete:
            state = torch.from_numpy(state).to(self.device)
        else:
            state = torch.from_numpy(state).float().unsqueeze(0).to(
                self.device)
        return self.net(state)

    def update(self, rewards, log_probs, distributions):

        ### Compute MC discounted returns ###

        Gamma = np.array([self.gamma**i for i in range(rewards.shape[0])])
        # reverse everything to use cumsum in right order, then reverse again
        Gt = np.cumsum(rewards[::-1] * Gamma[::-1])[::-1]
        # Rescale so that present reward is never discounted
        discounted_rewards = Gt / Gamma

        dr = torch.tensor(discounted_rewards).to(self.device)
        dr = (dr - dr.mean()) / dr.std()

        policy_gradient = []
        for log_prob, Gt in zip(log_probs, dr):
            policy_gradient.append(
                -log_prob * Gt)  # "-" for minimization instead of maximization

        distributions = torch.stack(distributions).squeeze()  # shape = (T, n_actions)
        # Negative entropy (sum of p*log p, note the absent minus sign); adding it to the loss rewards high entropy
        entropy = torch.sum(distributions * torch.log(distributions),
                            axis=1).sum()
        policy_grad = torch.stack(policy_gradient).sum()
        loss = policy_grad + self.H * entropy

        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        return policy_grad.item()
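
A minimal usage sketch (an assumption, not part of the original code: it presumes the project's Actor network, a gym environment with the classic 4-tuple step API, and illustrative hyperparameters):

import gym
import numpy as np

env = gym.make('CartPole-v1')                  # 4-dim float state, 2 discrete actions
agent = PolicyGradEnt(observation_space=4, action_space=2,
                      lr=1e-3, gamma=0.99, H=1e-2, discrete=False)

for episode in range(500):
    state, done = env.reset(), False
    rewards, log_probs, dists = [], [], []
    while not done:
        action, log_prob, dist = agent.get_action(state, return_log=True)
        state, reward, done, _ = env.step(action)
        rewards.append(reward)
        log_probs.append(log_prob)
        dists.append(dist)
    agent.update(np.array(rewards), log_probs, dists)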
Example #2
import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

# `Actor` and `Critic` are the project's policy/value networks (defined elsewhere in the repo).
debug = False  # module-level flag used by the verbose prints in the update methods below


class A2C():
    """
    Advantage Actor-Critic RL agent. 
    
    Notes
    -----
    * GPU implementation is still work in progress.
    * Always uses 2 separate networks for the critic: one that learns from new experience 
      (student/critic) and the other one (critic_target/teacher) that is more conservative 
      and whose weights are updated through an exponential moving average of the weights 
      of the critic, i.e.
          target.params = (1-tau)*target.params + tau*critic.params
    * In the case of Monte Carlo estimation the critic_target is never used.
    * It is possible to use twin networks for both the critic and the critic target for improved 
      stability. The critic target is used for the updates of both the actor and the critic, and
      its output is the minimum of the predictions of its two internal networks.
      
    """
    def __init__(self,
                 observation_space,
                 action_space,
                 lr,
                 gamma,
                 TD=True,
                 discrete=False,
                 project_dim=8,
                 hiddens=[64, 32],
                 twin=False,
                 tau=1.,
                 n_steps=1,
                 device='cpu',
                 debug=False):
        """
        Parameters
        ----------
        observation_space: int
            Number of flattened entries of the state
        action_space: int
            Number of (discrete) possible actions to take
        lr: float in [0,1]
            Learning rate
        gamma: float in [0,1]
            Discount factor
        TD: bool (default=True)
            If True, uses Temporal Difference for the critic's estimates
            Otherwise uses Monte Carlo estimation
        discrete: bool (default=False)
            If True, adds an embedding layer both in the actor 
            and the critic networks before processing the state.
            Should be used if the state is a simple integer in [0, observation_space -1]
        project_dim: int (default=8)
            Number of dimensions of the embedding space (e.g. number of dimensions of
            embedding(state) ). Higher dimensions are more expressive.
        hiddens: list of int (default = [64,32])
            List containing the number of neurons of each linear hidden layer.
            The same architecture is used for the actor and the critic, except for the 
            output layer, which for the actor has the dimension of the action space with a LogSoftmax
            activation, while for the critic it outputs a scalar (state value).
        twin: bool (default=False)
            Enables twin networks both for critic and critic_target
        tau: float in [0,1] (default = 1.)
            Regulates how fast the critic_target gets updated, i.e. what fraction of the weights
            it inherits from the critic. If tau=1., critic and critic_target are identical 
            at every step; if tau=0., critic_target never changes. 
            By default this feature is disabled by setting tau=1; if one wants to use it, a good
            empirical value is 0.005.
        n_steps: int (default=1)
            Number of steps considered in the TD update
        device: str in {'cpu','cuda'} (default='cpu')
            Both are implemented, but the GPU is slower than the CPU because it is hard to make an
            on-policy agent efficient without a replay buffer, which only off-policy algorithms can use.
        """

        self.gamma = gamma
        self.lr = lr

        self.n_actions = action_space
        self.discrete = discrete
        self.TD = TD
        self.twin = twin
        self.tau = tau
        self.n_steps = n_steps

        self.actor = Actor(observation_space,
                           action_space,
                           discrete,
                           project_dim,
                           hiddens=hiddens)
        self.critic = Critic(observation_space,
                             discrete,
                             project_dim,
                             twin,
                             hiddens=hiddens)

        if self.TD:
            self.critic_trg = Critic(observation_space,
                                     discrete,
                                     project_dim,
                                     twin,
                                     target=True,
                                     hiddens=hiddens)

            # Init critic target identical to critic
            for trg_params, params in zip(self.critic_trg.parameters(),
                                          self.critic.parameters()):
                trg_params.data.copy_(params.data)

        self.actor_optim = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_optim = torch.optim.Adam(self.critic.parameters(), lr=lr)

        self.device = device
        self.actor.to(self.device)
        self.critic.to(self.device)
        if self.TD:
            self.critic_trg.to(self.device)

        if debug:
            print("=" * 10 + " A2C HyperParameters " + "=" * 10)
            print("Discount factor: ", self.gamma)
            print("Learning rate: ", self.lr)
            print("Action space: ", self.n_actions)
            print("Discrete state space: ", self.discrete)
            print("Temporal Difference learning: ", self.TD)
            if self.TD:
                print("Number of TD steps: ", self.n_steps)
            print("Twin networks: ", self.twin)
            print("Update critic target factor: ", self.tau)
            print("Device used: ", self.device)
            print("\n\n" + "=" * 10 + " A2C Architecture " + "=" * 10)
            print("Actor architecture: \n", self.actor)
            print("Critic architecture: \n", self.critic)
            print("Critic target architecture: ")
            if self.TD:
                print(self.critic_trg)
            else:
                print("Not used")

    def get_action(self, state, return_log=False):
        log_probs = self.forward(state)
        dist = torch.exp(log_probs)
        probs = Categorical(dist)
        action = probs.sample().item()
        if return_log:
            return action, log_probs.view(-1)[action]
        else:
            return action

    def forward(self, state):
        """
        Converts a numpy state into a tensor and forwards
        it through the actor network.
        
        Parameters
        ----------
        state:
            If self.discrete is True state.shape = (episode_len,)
            Otherwise state.shape = (episode_len, observation_space)
        """
        if self.discrete:
            state = torch.from_numpy(state).to(self.device)
        else:
            state = torch.from_numpy(state).float().unsqueeze(0).to(
                self.device)
        log_probs = self.actor(state)
        return log_probs

    def update(self, *args):
        if self.TD:
            critic_loss, actor_loss = self.update_TD(*args)
        else:
            critic_loss, actor_loss = self.update_MC(*args)

        return critic_loss, actor_loss

    def update_TD(self, rewards, log_probs, states, done, bootstrap=None):

        ### Compute n-steps rewards, states, discount factors and done mask ###

        n_step_rewards = self.compute_n_step_rewards(rewards)
        if debug:
            print("n_step_rewards.shape: ", n_step_rewards.shape)
            print("rewards.shape: ", rewards.shape)
            print("n_step_rewards: ", n_step_rewards)
            print("rewards: ", rewards)

        if bootstrap is not None:
            done[bootstrap] = False
        if debug:
            print("done.shape: (before n_steps)", done.shape)
            print("done: (before n_steps)", done)

        if self.discrete:
            old_states = torch.tensor(states[:-1]).to(self.device)

            new_states, Gamma_V, done = self.compute_n_step_states(
                states, done)
            new_states = torch.tensor(new_states).to(self.device)

        else:
            old_states = torch.tensor(states[:, :-1]).float().to(self.device)

            new_states, Gamma_V, done = self.compute_n_step_states(
                states[0], done)
            new_states = torch.tensor(new_states).float().unsqueeze(0).to(
                self.device)

        if debug:
            print("done.shape: (after n_steps)", done.shape)
            print("Gamma_V.shape: ", Gamma_V.shape)
            print("done: (after n_steps)", done)
            print("Gamma_V: ", Gamma_V)
            print("old_states.shape: ", old_states.shape)
            print("new_states.shape: ", new_states.shape)

        ### Wrap variables into tensors ###

        done = torch.LongTensor(done.astype(int)).to(self.device)
        log_probs = torch.stack(log_probs).to(self.device)
        n_step_rewards = torch.tensor(n_step_rewards).float().to(self.device)
        Gamma_V = torch.tensor(Gamma_V).float().to(self.device)

        ### Update critic and then actor ###
        critic_loss = self.update_critic_TD(n_step_rewards, new_states,
                                            old_states, done, Gamma_V)
        actor_loss = self.update_actor_TD(n_step_rewards, log_probs,
                                          new_states, old_states, done,
                                          Gamma_V)

        return critic_loss, actor_loss

    def update_critic_TD(self, n_step_rewards, new_states, old_states, done,
                         Gamma_V):

        # Compute loss

        with torch.no_grad():
            V_trg = self.critic_trg(new_states).squeeze()
            if debug:
                print("V_trg.shape (after critic): ", V_trg.shape)
            V_trg = (1 - done) * Gamma_V * V_trg + n_step_rewards
            if debug:
                print("V_trg.shape (after sum): ", V_trg.shape)
            V_trg = V_trg.squeeze()
            if debug:
                print("V_trg.shape (after squeeze): ", V_trg.shape)

        if self.twin:
            V1, V2 = self.critic(old_states)
            loss1 = 0.5 * F.mse_loss(V1.squeeze(), V_trg)
            loss2 = 0.5 * F.mse_loss(V2.squeeze(), V_trg)
            loss = loss1 + loss2
        else:
            V = self.critic(old_states).squeeze()
            loss = F.mse_loss(V, V_trg)

        # Backpropagate and update

        self.critic_optim.zero_grad()
        loss.backward()
        self.critic_optim.step()

        # Update critic_target: (1-tau)*old + tau*new

        for trg_params, params in zip(self.critic_trg.parameters(),
                                      self.critic.parameters()):
            trg_params.data.copy_((1. - self.tau) * trg_params.data +
                                  self.tau * params.data)

        return loss.item()

    def update_actor_TD(self, n_step_rewards, log_probs, new_states,
                        old_states, done, Gamma_V):

        # Compute gradient

        if self.twin:
            V1, V2 = self.critic(old_states)
            V_pred = torch.min(V1.squeeze(), V2.squeeze())
            V1_new, V2_new = self.critic(new_states)
            V_new = torch.min(V1_new.squeeze(), V2_new.squeeze())
            V_trg = (1 - done) * Gamma_V * V_new + n_step_rewards
        else:
            V_pred = self.critic(old_states).squeeze()
            V_trg = (1 - done) * Gamma_V * self.critic(
                new_states).squeeze() + n_step_rewards

        A = V_trg - V_pred
        policy_gradient = -log_probs * A
        if debug:
            print("V_trg.shape: ", V_trg.shape)
            print("V_pred.shape: ", V_pred.shape)
            print("A.shape: ", A.shape)
            print("policy_gradient.shape: ", policy_gradient.shape)
        policy_grad = torch.sum(policy_gradient)

        # Backpropagate and update

        self.actor_optim.zero_grad()
        policy_grad.backward()
        self.actor_optim.step()

        return policy_grad.item()

    def compute_n_step_rewards(self, rewards):
        """
        Computes n-step discounted rewards, padding the end of the trajectory with zeros.
        This means that AT MOST n rewards are summed, but fewer for the last n-1 elements.
        """
        T = len(rewards)

        # concatenate n_steps zeros to the rewards -> they do not change the cumsum
        r = np.concatenate((rewards, [0 for _ in range(self.n_steps)]))

        Gamma = np.array([self.gamma**i for i in range(r.shape[0])])

        # reverse everything to use cumsum in right order, then reverse again
        Gt = np.cumsum(r[::-1] * Gamma[::-1])[::-1]

        G_nstep = Gt[:T] - Gt[
            self.n_steps:]  # compute n-steps discounted return

        Gamma = Gamma[:T]

        assert len(
            G_nstep) == T, "Something went wrong computing n-steps reward"

        n_steps_r = G_nstep / Gamma

        return n_steps_r
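
    # Worked example (a sketch, not from the original code): gamma = 0.5, n_steps = 2, rewards = [1, 2, 3]
    #   padded r  = [1, 2, 3, 0, 0]
    #   Gt        = [2.75, 1.75, 0.75, 0, 0]      (reversed discounted cumsum, reversed back)
    #   G_nstep   = Gt[:3] - Gt[2:] = [2.0, 1.75, 0.75]
    #   n_steps_r = G_nstep / Gamma[:3] = [2.0, 3.5, 3.0]
    # i.e. r0 + 0.5*r1 = 2.0, r1 + 0.5*r2 = 3.5, r2 = 3.0 (fewer than n steps at the end).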

    def compute_n_step_states(self, states, done):
        """
        Computes the n-step target states (to be used by the critic as target values together with
        the n-step discounted rewards). For the last n-1 elements the target state is the last one available.
        Also adjusts the `done` mask used to disable bootstrapping for terminal states,
        and returns Gamma_V, the discount factors for the target state-values, since they are 
        n steps away (except for the last n-1 states, whose discount is adjusted accordingly).
        
        Returns
        -------
        new_states, Gamma_V, done: arrays with first dimension = len(states)-1
        """

        # Compute indexes for (at most) n-step away states

        n_step_idx = np.arange(len(states) - 1) + self.n_steps
        diff = n_step_idx - len(states) + 1
        mask = (diff > 0)
        n_step_idx[mask] = len(states) - 1

        # Compute new states

        new_states = states[n_step_idx]

        # Compute discount factors

        pw = np.array([self.n_steps for _ in range(len(new_states))])
        pw[mask] = self.n_steps - diff[mask]
        Gamma_V = self.gamma**pw

        # Adjust done mask

        mask = (diff >= 0)
        done[mask] = done[-1]

        return new_states, Gamma_V, done
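
    # Worked example (a sketch, not from the original code): n_steps = 2, len(states) = 4:
    #   n_step_idx = [2, 3, 4] -> clipped to [2, 3, 3]
    #   pw         = [2, 2, 1]      (the last target state is only 1 step away)
    #   Gamma_V    = gamma ** pw
    #   done[-2:]  = done[-1]       (entries with diff >= 0 inherit the final terminal flag)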

    def update_MC(self, rewards, log_probs, states, done, bootstrap=None):

        ### Compute MC discounted returns ###

        if bootstrap is not None:

            if bootstrap[-1]:

                last_state = torch.tensor(states[0, -1, :]).float().to(
                    self.device).view(1, -1)

                if self.twin:
                    V1, V2 = self.critic(last_state)
                    V_bootstrap = torch.min(V1,
                                            V2).cpu().detach().numpy().reshape(
                                                1, )
                else:
                    V_bootstrap = self.critic(
                        last_state).cpu().detach().numpy().reshape(1, )

                rewards = np.concatenate((rewards, V_bootstrap))

        Gamma = np.array([self.gamma**i for i in range(rewards.shape[0])])
        # reverse everything to use cumsum in right order, then reverse again
        Gt = np.cumsum(rewards[::-1] * Gamma[::-1])[::-1]
        # Rescale so that present reward is never discounted
        discounted_rewards = Gt / Gamma

        if bootstrap is not None:
            if bootstrap[-1]:
                discounted_rewards = discounted_rewards[:-1]  # drop last

        ### Wrap variables into tensors ###

        dr = torch.tensor(discounted_rewards).float().to(self.device)

        if self.discrete:
            old_states = torch.tensor(states[:-1]).to(self.device)
            new_states = torch.tensor(states[1:]).to(self.device)
        else:
            old_states = torch.tensor(states[:, :-1]).float().to(self.device)
            new_states = torch.tensor(states[:, 1:]).float().to(self.device)

        done = torch.LongTensor(done.astype(int)).to(self.device)
        log_probs = torch.stack(log_probs).to(self.device)

        ### Update critic and then actor ###

        critic_loss = self.update_critic_MC(dr, old_states)
        actor_loss = self.update_actor_MC(dr, log_probs, old_states)

        return critic_loss, actor_loss

    def update_critic_MC(self, dr, old_states):

        # Compute loss

        if self.twin:
            V1, V2 = self.critic(old_states)
            V_pred = torch.min(V1.squeeze(), V2.squeeze())
        else:
            V_pred = self.critic(old_states).squeeze()

        loss = F.mse_loss(V_pred, dr)

        # Backpropagate and update

        self.critic_optim.zero_grad()
        loss.backward()
        self.critic_optim.step()

        return loss.item()

    def update_actor_MC(self, dr, log_probs, old_states):

        # Compute gradient

        if self.twin:
            V1, V2 = self.critic(old_states)
            V_pred = torch.min(V1.squeeze(), V2.squeeze())
        else:
            V_pred = self.critic(old_states).squeeze()

        A = dr - V_pred
        policy_gradient = -log_probs * A
        policy_grad = torch.sum(policy_gradient)

        # Backpropagate and update

        self.actor_optim.zero_grad()
        policy_grad.backward()
        self.actor_optim.step()

        return policy_grad.item()
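
A minimal episode-collection sketch for the TD path (an assumption, not part of the original code; the shapes passed to update are inferred from update_TD, and a gym environment with the classic 4-tuple step API plus the project's Actor/Critic networks are presumed):

import gym
import numpy as np

env = gym.make('CartPole-v1')
agent = A2C(observation_space=4, action_space=2, lr=1e-3, gamma=0.99,
            TD=True, n_steps=3, tau=0.005)

states, rewards, log_probs, done_flags = [], [], [], []
s, terminal = env.reset(), False
states.append(s)
while not terminal:
    a, log_prob = agent.get_action(s, return_log=True)
    s, r, terminal, _ = env.step(a)
    states.append(s)
    rewards.append(r)
    log_probs.append(log_prob)
    done_flags.append(terminal)

critic_loss, actor_loss = agent.update(np.array(rewards), log_probs,
                                       np.array(states)[None, ...],  # shape (1, T+1, obs_dim)
                                       np.array(done_flags))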
Example #3
import os

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim

# `Actor`, `Critic`, `ActorCriticCNN`, `Memory`, `OrnsteinUhlenbeckActionNoise`, `soft_update`
# and `hard_update` are project-specific components defined elsewhere in the repo
# (a sketch of the assumed soft_update/hard_update helpers follows the class).

class A2C():
    def __init__(self, state_dim, action_dim, action_lim, update_type='soft',
                lr_actor=1e-4, lr_critic=1e-3, tau=1e-3,
                mem_size=1e6, batch_size=256, gamma=0.99,
                other_cars=False, ego_dim=None):
        self.device = torch.device("cuda:0" if torch.cuda.is_available()
                                        else "cpu")

        self.joint_model = False
        if len(state_dim) == 3:
            self.model = ActorCriticCNN(state_dim, action_dim, action_lim)
            self.model_optim = optim.Adam(self.model.parameters(), lr=lr_actor)

            self.target_model = ActorCriticCNN(state_dim, action_dim, action_lim)
            self.target_model.load_state_dict(self.model.state_dict())

            self.model.to(self.device)
            self.target_model.to(self.device)

            self.joint_model = True
        else:
            self.actor = Actor(state_dim, action_dim, action_lim, other_cars=other_cars, ego_dim=ego_dim)
            self.actor_optim = optim.Adam(self.actor.parameters(), lr=lr_actor)
            self.target_actor = Actor(state_dim, action_dim, action_lim, other_cars=other_cars, ego_dim=ego_dim)
            self.target_actor.load_state_dict(self.actor.state_dict())
            self.target_actor.eval()

            self.critic = Critic(state_dim, action_dim, other_cars=other_cars, ego_dim=ego_dim)
            self.critic_optim = optim.Adam(self.critic.parameters(), lr=lr_critic, weight_decay=1e-2)
            self.target_critic = Critic(state_dim, action_dim, other_cars=other_cars, ego_dim=ego_dim)
            self.target_critic.load_state_dict(self.critic.state_dict())
            self.target_critic.eval()

            self.actor.to(self.device)
            self.target_actor.to(self.device)
            self.critic.to(self.device)
            self.target_critic.to(self.device)

        self.action_lim = action_lim
        self.tau = tau  # soft-update coefficient (only used when update_type == 'soft')
        self.update_type = update_type
        self.batch_size = batch_size
        self.gamma = gamma

        if self.joint_model:
            mem_size = mem_size//100
        self.memory = Memory(int(mem_size), action_dim, state_dim)

        mu = np.zeros(action_dim)
        sigma = np.array([0.5, 0.05])  # per-dimension noise scale (assumes a 2-dimensional action space)
        self.noise = OrnsteinUhlenbeckActionNoise(mu, sigma)
        self.target_noise = OrnsteinUhlenbeckActionNoise(mu, sigma)

        self.initialised = True
        self.training = False

    def select_action(self, obs):
        with torch.no_grad():
            obs = torch.FloatTensor(np.expand_dims(obs, axis=0)).to(self.device)
            if self.joint_model:
                action, _ = self.model(obs)
                action = action.data.cpu().numpy().flatten()
            else:
                action = self.actor(obs).data.cpu().numpy().flatten()

        if self.training:
            action += self.noise()
            return action
        else:
            return action

    def append(self, obs0, action, reward, obs1, terminal1):
        self.memory.append(obs0, action, reward, obs1, terminal1)

    def reset_noise(self):
        self.noise.reset()
        self.target_noise.reset()

    def train(self):
        if self.joint_model:
            self.model.train()
            self.target_model.train()
        else:
            self.actor.train()
            self.target_actor.train()
            self.critic.train()
            self.target_critic.train()

        self.training = True

    def eval(self):
        if self.joint_model:
            self.model.eval()
            self.target_model.eval()
        else:
            self.actor.eval()
            self.target_actor.eval()
            self.critic.eval()
            self.target_critic.eval()

        self.training = False

    def save(self, folder, episode, previous=None, solved=False):
        filename = lambda type, ep : folder + '%s' % type + \
                                    (not solved) * ('_ep%d' % (ep)) + \
                                    (solved * '_solved') + '.pth'

        if self.joint_model:
            torch.save(self.model.state_dict(), filename('model', episode))
            torch.save(self.target_model.state_dict(), filename('target_model', episode))
        else:
            torch.save(self.actor.state_dict(), filename('actor', episode))
            torch.save(self.target_actor.state_dict(), filename('target_actor', episode))

            torch.save(self.critic.state_dict(), filename('critic', episode))
            torch.save(self.target_critic.state_dict(), filename('target_critic', episode))

        if previous is not None and previous > 0:
            if self.joint_model:
                os.remove(filename('model', previous))
                os.remove(filename('target_model', previous))
            else:
                os.remove(filename('actor', previous))
                os.remove(filename('target_actor', previous))
                os.remove(filename('critic', previous))
                os.remove(filename('target_critic', previous))

    def load_actor(self, actor_filepath):
        qualifier = '_' + actor_filepath.split("_")[-1]
        folder = actor_filepath[:actor_filepath.rfind("/")+1]
        filename = lambda type : folder + '%s' % type + qualifier

        if self.joint_model:
            self.model.load_state_dict(torch.load(filename('model'),
                                                    map_location=self.device))
            self.target_model.load_state_dict(torch.load(filename('target_model'),
                                                    map_location=self.device))
        else:
            self.actor.load_state_dict(torch.load(filename('actor'),
                                                    map_location=self.device))
            self.target_actor.load_state_dict(torch.load(filename('target_actor'),
                                                    map_location=self.device))

    def load_all(self, actor_filepath):
        self.load_actor(actor_filepath)
        qualifier = '_' + actor_filepath.split("_")[-1]
        folder = actor_filepath[:actor_filepath.rfind("/")+1]
        filename = lambda type : folder + '%s' % type + qualifier

        if not self.joint_model:
            self.critic.load_state_dict(torch.load(filename('critic'),
                                                    map_location=self.device))
            self.target_critic.load_state_dict(torch.load(filename('target_critic'),
                                                    map_location=self.device))

    def update(self, target_noise=True):
        try:
            minibatch = self.memory.sample(self.batch_size) # dict of ndarrays
        except ValueError as e:
            print('Replay memory not big enough. Continue.')
            return None, None

        # torch.autograd.Variable is deprecated; plain tensors carry gradients since PyTorch 0.4
        states = torch.FloatTensor(minibatch['obs0']).to(self.device)
        actions = torch.FloatTensor(minibatch['actions']).to(self.device)
        rewards = torch.FloatTensor(minibatch['rewards']).to(self.device)
        next_states = torch.FloatTensor(minibatch['obs1']).to(self.device)
        terminals = torch.FloatTensor(minibatch['terminals1']).to(self.device)

        if self.joint_model:
            target_actions, _ = self.target_model(next_states)
            if target_noise:
                for sample in range(target_actions.shape[0]):
                    target_actions[sample] += self.target_noise()
                    # clamp is not in-place: assign the clipped result back (the original call discarded it)
                    target_actions[sample] = target_actions[sample].clamp(-self.action_lim, self.action_lim)
            _, target_qvals = self.target_model(next_states, target_actions=target_actions)
            y = rewards + self.gamma * (1 - terminals) * target_qvals

            _, model_qvals = self.model(states, target_actions=actions)
            value_loss = F.mse_loss(y, model_qvals)
            model_actions, _ = self.model(states)
            _, model_qvals = self.model(states, target_actions=model_actions)
            action_loss = -model_qvals.mean()

            self.model_optim.zero_grad()
            (value_loss + action_loss).backward()
            self.model_optim.step()
        else:
            target_actions = self.target_actor(next_states)
            if target_noise:
                for sample in range(target_actions.shape[0]):
                    target_actions[sample] += self.target_noise()
                    # clamp is not in-place: assign the clipped result back
                    target_actions[sample] = target_actions[sample].clamp(-self.action_lim, self.action_lim)
            target_critic_qvals = self.target_critic(next_states, target_actions)
            y = rewards + self.gamma * (1 - terminals) * target_critic_qvals

            # optimise critic
            critic_qvals = self.critic(states, actions)
            value_loss = F.mse_loss(y, critic_qvals)
            self.critic_optim.zero_grad()
            value_loss.backward()
            self.critic_optim.step()

            # optimise actor
            action_loss = -self.critic(states, self.actor(states)).mean()
            self.actor_optim.zero_grad()
            action_loss.backward()
            self.actor_optim.step()

        # optimise target networks
        if self.update_type == 'soft':
            if self.joint_model:
                soft_update(self.target_model, self.model, self.tau)
            else:
                soft_update(self.target_actor, self.actor, self.tau)
                soft_update(self.target_critic, self.critic, self.tau)
        else:
            if self.joint_model:
                hard_update(self.target_model, self.model)
            else:
                hard_update(self.target_actor, self.actor)
                hard_update(self.target_critic, self.critic)

        return action_loss.item(), value_loss.item()
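
soft_update and hard_update are called above but not shown in this excerpt; a sketch of their assumed behaviour (standard Polyak averaging and a hard parameter copy) is:

import torch.nn as nn

def soft_update(target: nn.Module, source: nn.Module, tau: float):
    # target <- (1 - tau) * target + tau * source, parameter-wise (Polyak averaging)
    for t_param, s_param in zip(target.parameters(), source.parameters()):
        t_param.data.copy_((1.0 - tau) * t_param.data + tau * s_param.data)

def hard_update(target: nn.Module, source: nn.Module):
    # target <- source (exact copy of all parameters and buffers)
    target.load_state_dict(source.state_dict())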