Example #1
0
class DDPGAgent:
    """DDPG agent with a per-agent actor and a critic over all agents' actions.

    The actor maps this agent's observation to its action. The critic is built
    over ``state_size`` inputs plus the concatenated actions of all agents
    (``action_size * num_agents``).

    Args:
        state_size: state size passed to the critic networks.
        obs_size: observation size passed to the actor networks.
        action_size: size of one agent's action.
        num_agents: number of agents whose actions the critic consumes.
    """

    def __init__(self, state_size, obs_size, action_size, num_agents):
        super(DDPGAgent, self).__init__()

        self.actor = ActorNetwork(obs_size, action_size).to(device)
        self.critic = CriticNetwork(state_size,
                                    action_size * num_agents).to(device)
        self.target_actor = ActorNetwork(obs_size, action_size).to(device)
        self.target_critic = CriticNetwork(state_size,
                                           action_size * num_agents).to(device)

        self.noise = OUNoise(action_size, scale=1.0)

        # initialize targets same as original networks
        hard_update(self.target_actor, self.actor)
        hard_update(self.target_critic, self.critic)

        self.actor_optimizer = Adam(self.actor.parameters(), lr=LR_ACTOR)
        self.critic_optimizer = Adam(self.critic.parameters(),
                                     lr=LR_CRITIC,
                                     weight_decay=WEIGHT_DECAY)

    @staticmethod
    def _to_tensor(obs):
        """Convert a NumPy observation to a float tensor on ``device``; pass anything else through."""
        if isinstance(obs, np.ndarray):
            obs = torch.from_numpy(obs).float().to(device)
        return obs

    def act(self, obs, noise=0.0):
        """Return the local actor's action for ``obs`` plus ``noise``-scaled OU noise.

        The returned tensor is not detached, so it remains part of the actor's
        autograd graph. ``self.noise.noise()`` is called even when ``noise`` is 0.
        """
        obs = self._to_tensor(obs)
        # Out-of-place add: ``action += ...`` would modify the actor's output in
        # place, which makes backward() fail whenever the actor's last op saved
        # that output for its gradient (e.g. tanh).
        return self.actor(obs) + noise * self.noise.noise()

    def target_act(self, obs, noise=0.0):
        """Return the target actor's action for ``obs`` plus ``noise``-scaled OU noise."""
        obs = self._to_tensor(obs)
        # Out-of-place for the same autograd reason as in ``act``.
        return self.target_actor(obs) + noise * self.noise.noise()
Example #2
0
class Learner:
    """Learner that pulls batches from a queue and updates an actor and a critic.

    The environment named by ``opt.env`` is created (and seeded with
    ``opt.seed``) to read the state size and number of discrete actions. Both
    networks are placed in shared memory and get their own Adam optimizer with
    learning rate ``opt.lr``.
    """

    def __init__(self, opt, q_batch):
        self.opt = opt
        self.q_batch = q_batch

        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")

        self.env = gym.make(opt.env)
        self.env.seed(opt.seed)
        self.n_state = self.env.observation_space.shape[0]
        self.n_act = self.env.action_space.n

        self.actor = ActorNetwork(self.n_state, self.n_act).to(self.device)
        self.critic = CriticNetwork(self.n_state).to(self.device)
        for shared_net in (self.actor, self.critic):
            shared_net.share_memory()
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=opt.lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=opt.lr)

    def learning(self):
        """Loop forever: take one actor step and one critic step per queued batch.

        Each queue item is unpacked as ``(states, actions, rewards)``.
        """
        torch.manual_seed(self.opt.seed)

        while True:
            states, actions, rewards = self.q_batch.get(block=True)
            onehot_actions = torch.FloatTensor(
                index2onehot(actions, self.n_act)).to(self.device)

            # Actor step: log-prob of the taken action weighted by a
            # critic-baselined advantage (the baseline carries no gradient).
            self.actor_optimizer.zero_grad()
            chosen_log_probs = (self.actor(states) * onehot_actions).sum(1)
            values = self.critic(states)
            advantages = rewards - values.detach()
            actor_loss = -(chosen_log_probs * advantages).sum()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Critic step: regress the same value estimates toward ``rewards``.
            self.critic_optimizer.zero_grad()
            critic_loss = nn.MSELoss()(values, rewards)
            critic_loss.backward()
            self.critic_optimizer.step()
Example #3
0
class Critic():
    """Local and target critic networks with an Adam optimizer on the local one.

    Args:
        state_size (int): dimension of each state
        action_size (int): dimension of each action
        random_seed (int): seed for Python's ``random`` module; also passed to
            both networks
        learning_rate (float): Adam learning rate for the local critic
        weight_decay (float): Adam weight decay for the local critic
        device: device both networks are moved to
    """

    def __init__(self, state_size, action_size, random_seed, learning_rate,
                 weight_decay, device):

        self.state_size = state_size
        self.action_size = action_size
        # random.seed() returns None, so seed the RNG and store the seed itself.
        random.seed(random_seed)
        self.seed = random_seed
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.device = device

        self.critic_local = CriticNetwork(state_size, action_size,
                                          random_seed).to(device)
        self.critic_target = CriticNetwork(state_size, action_size,
                                           random_seed).to(device)
        # Initialize the target from the local network's weights.
        hard_update(self.critic_target, self.critic_local)
        self.critic_optimizer = optim.Adam(self.critic_local.parameters(),
                                           lr=self.learning_rate,
                                           weight_decay=weight_decay)
Example #4
0
class Agent():
    """Interacts with and learns from the environment (DDPG with replay memory)."""

    def __init__(self, state_size, action_size, random_seed, agent_size=1):
        """Initialize an Agent object.

        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            random_seed (int): random seed
            agent_size (int): number of leading action rows that ``act`` adds a
                fresh noise sample to
        """
        self.state_size = state_size
        self.action_size = action_size
        # random.seed() returns None, so seed the RNG and store the seed itself.
        random.seed(random_seed)
        self.seed = random_seed
        self.agent_size = agent_size

        self.local_actor = ActorNetwork(state_size, action_size,
                                        random_seed).to(device)
        self.target_actor = ActorNetwork(state_size, action_size,
                                         random_seed).to(device)
        self.local_critic = CriticNetwork(state_size, action_size,
                                          random_seed).to(device)
        self.target_critic = CriticNetwork(state_size, action_size,
                                           random_seed).to(device)

        self.opt_actor = optim.Adam(self.local_actor.parameters(), lr=LR_ACTOR)
        self.opt_critic = optim.Adam(self.local_critic.parameters(),
                                     lr=LR_CRITIC,
                                     weight_decay=WEIGHT_DECAY)

        self.noise = OUNoise(action_size, random_seed)

        # Replay memory
        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE,
                                   random_seed)

    def save_experience(self, state, action, reward, next_state, done):
        """Store one experience tuple in replay memory (learning happens in ``multi_step``)."""
        self.memory.add(state, action, reward, next_state, done)

    def multi_step(self, t, update_every=20, learn_times=10):
        """Run several learning passes every ``update_every`` time steps.

        Nothing happens until the replay buffer holds more than ``BATCH_SIZE``
        experiences. After that, on every step where ``t % update_every == 0``,
        ``learn_times`` batches are sampled and passed to ``learn`` with ``GAMMA``.

        Params
        ======
            t (int): current time step
            update_every (int): learn only on steps divisible by this value
            learn_times (int): number of sampled batches learned from per update
        """
        if len(self.memory) <= BATCH_SIZE or t % update_every != 0:
            return
        for _ in range(learn_times):
            self.learn(self.memory.sample(), GAMMA)

    def act(self, state, add_noise=True):
        """Return actions for ``state`` per the current policy, clipped to [-1, 1].

        When ``add_noise`` is true, a fresh ``self.noise.sample()`` is added to
        each of the first ``agent_size`` rows of the action array.
        """
        state = torch.from_numpy(state).float().to(device)
        self.local_actor.eval()
        with torch.no_grad():
            action = self.local_actor(state).cpu().data.numpy()
        self.local_actor.train()
        if add_noise:
            for a in range(self.agent_size):
                action[a] += self.noise.sample()
        return np.clip(action, -1, 1)  # all actions between -1 and 1

    def reset(self):
        """Reset the exploration noise process."""
        self.noise.reset()

    def learn(self, experiences, gamma):
        """Update critic and actor from one batch, then soft-update both targets.

        Target and local critics/actors are used to address the moving-target
        problem: the target actor proposes the next action and the target critic
        scores it.

        Q_targets = r + gamma * critic_t(next_state, actor_t(next_state))

        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences

        # ---------------------------- update critic ---------------------------- #
        # The TD target is a fixed regression label: build it without autograd
        # so critic_loss.backward() does not push gradients into the targets.
        with torch.no_grad():
            actions_next = self.target_actor(next_states)
            Q_targets_next = self.target_critic(next_states, actions_next)
            Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
        Q_expected = self.local_critic(states, actions)
        critic_loss = F.mse_loss(Q_expected, Q_targets)
        self.opt_critic.zero_grad()
        critic_loss.backward()
        # use gradient clipping when training the critic network
        torch.nn.utils.clip_grad_norm_(self.local_critic.parameters(), 1)
        self.opt_critic.step()

        # ---------------------------- update actor ---------------------------- #
        actions_pred = self.local_actor(states)
        actor_loss = -self.local_critic(states, actions_pred).mean()
        self.opt_actor.zero_grad()
        actor_loss.backward()
        # Clipping must follow backward(): right after zero_grad() there are no
        # gradients to clip, which made the clip a silent no-op.
        torch.nn.utils.clip_grad_norm_(self.local_actor.parameters(), 1)
        self.opt_actor.step()

        # ----------------------- update target networks ----------------------- #
        self.soft_update(self.local_critic, self.target_critic, TAU)
        self.soft_update(self.local_actor, self.target_actor, TAU)

    def soft_update(self, local_model, target_model, tau):
        """Soft update model parameters in place.

        theta_target = tau*theta_local + (1 - tau)*theta_target

        Params
        ======
            local_model: PyTorch model (weights will be copied from)
            target_model: PyTorch model (weights will be copied to)
            tau (float): interpolation parameter
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)
class Agent():
    """Agent with its own actor and a critic over all agents' stacked inputs.

    Params
    ======
        state_size (int): dimension of one agent's state (actor input)
        action_size (int): dimension of one agent's action (actor output)
        n_agents (int): number of agents; the critic sees
            ``state_size * n_agents`` state and ``action_size * n_agents``
            action inputs
        seed (int): random seed
    """

    def __init__(self, state_size, action_size, n_agents, seed):
        self.state_size = state_size
        self.action_size = action_size
        # random.seed() returns None, so seed the RNG and store the seed itself.
        random.seed(seed)
        self.seed = seed

        self.stacked_state_size = state_size * n_agents
        self.stacked_action_size = action_size * n_agents

        # Actor networks
        self.actor_local = ActorNetwork(state_size, action_size,
                                        seed).to(device)
        self.actor_target = ActorNetwork(state_size, action_size,
                                         seed).to(device)
        self.actor_optimizer = optim.Adam(self.actor_local.parameters(),
                                          lr=ACTOR_LR)

        # Critic networks (over all agents' stacked states and actions)
        self.critic_local = CriticNetwork(self.stacked_state_size,
                                          self.stacked_action_size,
                                          seed).to(device)
        self.critic_target = CriticNetwork(self.stacked_state_size,
                                           self.stacked_action_size,
                                           seed).to(device)
        self.critic_optimizer = optim.Adam(self.critic_local.parameters(),
                                           lr=CRITIC_LR)

        # OUNoise
        self.exploration_noise = OUNoise(action_size, seed)

    def act(self, state):
        """Return the local actor's action plus OU noise, clipped to [-1, 1].

        A batch dimension of size 1 is added to ``state`` before the forward pass.
        """
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)

        self.actor_local.eval()
        with torch.no_grad():
            action = self.actor_local(state).cpu().data.numpy()
        self.actor_local.train()

        # Add exploration noise
        action += self.exploration_noise.sample()

        return np.clip(action, -1, 1)

    def update(self, states, current_agent_states, actions,
               current_agent_actions, target_next_actions, rewards,
               current_agent_rewards, next_states, dones, current_agent_dones,
               action_preds):
        """One critic step and one actor step, then soft target updates.

        ``states``, ``next_states`` and ``actions`` are reshaped to
        ``(BATCH_SIZE, -1)`` for the critic. ``target_next_actions`` and
        ``action_preds`` are passed to the critic unchanged — presumably already
        in the same flattened layout; confirm with the caller.
        ``current_agent_states``, ``current_agent_actions``, ``rewards`` and
        ``dones`` are accepted but unused.
        """
        flatten_states = torch.reshape(states, shape=(BATCH_SIZE, -1))
        flatten_next_states = torch.reshape(next_states,
                                            shape=(BATCH_SIZE, -1))
        flatten_actions = torch.reshape(actions, shape=(BATCH_SIZE, -1))

        # The TD target is a fixed label for the critic regression. Computing it
        # without autograd keeps critic_loss.backward() from reaching the target
        # critic and whatever produced ``target_next_actions``.
        with torch.no_grad():
            y = current_agent_rewards + GAMMA * self.critic_target(
                flatten_next_states,
                target_next_actions) * (1 - current_agent_dones)

        # Critic loss
        critic_loss = F.mse_loss(
            y, self.critic_local(flatten_states, flatten_actions))

        # Critic backprop
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Actor loss
        actor_loss = -self.critic_local(flatten_states, action_preds).mean()

        # Actor backprop
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Soft updates
        self.update_target_network()

    def update_target_network(self):
        """Soft-update both target networks: theta_t = TAU*theta_l + (1 - TAU)*theta_t."""
        pairs = ((self.actor_local, self.actor_target),
                 (self.critic_local, self.critic_target))
        for local_net, target_net in pairs:
            for target_param, local_param in zip(target_net.parameters(),
                                                 local_net.parameters()):
                target_param.data.copy_(TAU * local_param.data +
                                        (1.0 - TAU) * target_param.data)
class Critic:
    """Local/target critic pair that also drives the paired actor's DDPG update.

    Args:
        device: device the networks are placed on and checkpoints are mapped to.
        state_size, action_size: network input sizes.
        random_seed: seed for Python's ``random`` module; also passed to both
            networks.
        gamma: discount factor used in the TD target.
        TAU: soft-update interpolation factor.
        lr, weight_decay: Adam settings for the local critic.
        checkpoint_folder: folder holding ``checkpoint_critic.pth``. If that file
            exists, it is loaded into both the local and target networks.
    """

    def __init__(self, device, state_size, action_size, random_seed,
                 gamma, TAU, lr, weight_decay,
                 checkpoint_folder='./Saved_Model/'):

        self.DEVICE = device

        self.state_size = state_size
        self.action_size = action_size
        # random.seed() returns None, so seed the RNG and store the seed itself.
        random.seed(random_seed)
        self.seed = random_seed

        # Hyperparameters
        self.GAMMA = gamma
        self.TAU = TAU
        self.LR = lr
        self.WEIGHT_DECAY = weight_decay

        self.CHECKPOINT_FOLDER = checkpoint_folder

        # Critic Network (w/ Target Network)
        self.local = CriticNetwork(state_size, action_size, random_seed).to(self.DEVICE)
        self.target = CriticNetwork(state_size, action_size, random_seed).to(self.DEVICE)
        self.optimizer = optim.Adam(self.local.parameters(), lr=self.LR, weight_decay=self.WEIGHT_DECAY)

        # os.path.join also handles a folder given without a trailing separator.
        self.checkpoint_full_name = os.path.join(self.CHECKPOINT_FOLDER, 'checkpoint_critic.pth')
        if os.path.isfile(self.checkpoint_full_name):
            # Read the file once. Mapping onto this critic's device lets a
            # GPU-saved checkpoint load on a CPU-only machine.
            state_dict = torch.load(self.checkpoint_full_name, map_location=self.DEVICE)
            self.local.load_state_dict(state_dict)
            self.target.load_state_dict(state_dict)

    def step(self, actor, memory):
        """Sample a batch from ``memory`` and learn from it; do nothing if the sample is empty."""
        experiences = memory.sample()
        if not experiences:
            return

        self.learn(actor, experiences)

    def learn(self, actor, experiences):
        """Update critic and actor from one batch of experience tuples.

        Q_targets = r + γ * target(next_state, actor.target(next_state))
        where:
            actor.target(state) -> action
            target(state, action) -> Q-value

        Params
        ======
            actor: object exposing ``local``/``target`` networks and an ``optimizer``
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
        """
        states, actions, rewards, next_states, dones = experiences

        # ---------------------------- update critic ---------------------------- #
        # The TD target is a fixed label: build it without autograd so backward()
        # does not populate gradients on the target networks.
        with torch.no_grad():
            actions_next = actor.target(next_states)
            Q_targets_next = self.target(next_states, actions_next)
            Q_targets = rewards + (self.GAMMA * Q_targets_next * (1 - dones))
        Q_expected = self.local(states, actions)
        critic_loss = F.mse_loss(Q_expected, Q_targets)
        self.optimizer.zero_grad()
        critic_loss.backward()
        self.optimizer.step()

        # ---------------------------- update actor ---------------------------- #
        actions_pred = actor.local(states)
        actor_loss = - self.local(states, actions_pred).mean()
        actor.optimizer.zero_grad()
        actor_loss.backward()
        actor.optimizer.step()

        # ----------------------- update target networks ----------------------- #
        self.soft_update(self.local, self.target)
        self.soft_update(actor.local, actor.target)

    def soft_update(self, local_model, target_model):
        """Soft update model parameters using ``self.TAU`` as tau.

        θ_target = τ*θ_local + (1 - τ)*θ_target

        Params
        ======
            local_model: PyTorch model (weights will be copied from)
            target_model: PyTorch model (weights will be copied to)
        """
        tau = self.TAU
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)

    def checkpoint(self):
        """Save the local critic's state_dict, creating the checkpoint folder if needed."""
        folder = os.path.dirname(self.checkpoint_full_name)
        if folder:
            os.makedirs(folder, exist_ok=True)
        torch.save(self.local.state_dict(), self.checkpoint_full_name)
Example #7
0
class DDPGAgent:
    """DDPG agent: an actor on one agent's state and a critic sized for ``num_agents``.

    Network sizes, learning rates, critic weight decay, the seed and the target
    device are all configurable. Target networks start as copies of the local
    ones.
    """

    def __init__(self,
                 state_size,
                 action_size,
                 num_agents,
                 hidden_in_actor=512,
                 hidden_out_actor=256,
                 lr_actor=1e-4,
                 hidden_in_critic=512,
                 hidden_out_critic=256,
                 lr_critic=3e-4,
                 weight_decay_critic=0,
                 seed=1,
                 device='cpu'):
        super(DDPGAgent, self).__init__()

        self.device = device

        # Actor
        self.actor = ActorNetwork(state_size, hidden_in_actor,
                                  hidden_out_actor, action_size,
                                  seed).to(device)
        self.target_actor = ActorNetwork(state_size, hidden_in_actor,
                                         hidden_out_actor, action_size,
                                         seed).to(device)
        self.actor_optimizer = Adam(self.actor.parameters(), lr=lr_actor)

        # Critic
        self.critic = CriticNetwork(state_size, action_size, num_agents,
                                    hidden_in_critic, hidden_out_critic,
                                    seed).to(device)
        self.target_critic = CriticNetwork(state_size, action_size, num_agents,
                                           hidden_in_critic, hidden_out_critic,
                                           seed).to(device)
        self.critic_optimizer = Adam(self.critic.parameters(),
                                     lr=lr_critic,
                                     weight_decay=weight_decay_critic)

        # Noise
        self.noise = OUNoise(action_size, seed, scale=1.0)

        # initialize targets same as original networks
        hard_update(self.target_actor, self.actor)
        hard_update(self.target_critic, self.critic)

    def reset(self):
        """Reset the exploration noise process."""
        self.noise.reset()

    def _to_states(self, obs):
        """Return ``obs`` as a tensor on ``self.device`` (NumPy input is cast to float)."""
        if torch.is_tensor(obs):
            # Tensors may arrive on another device; .to() is a no-op if already there.
            return obs.to(self.device)
        return torch.from_numpy(obs).float().to(self.device)

    def act(self, obs, noise_factor=0.0):
        """Return the local actor's action as NumPy, plus scaled noise, clipped to [-1, 1]."""
        states = self._to_states(obs)

        self.actor.eval()
        with torch.no_grad():
            actions = self.actor(states).cpu().data.numpy()
        self.actor.train()
        actions += noise_factor * self.noise.sample()
        return np.clip(actions, -1, 1)

    def target_act(self, obs):
        """Return the target actor's action as NumPy, clipped to [-1, 1] (no noise)."""
        states = self._to_states(obs)

        self.target_actor.eval()
        with torch.no_grad():
            actions = self.target_actor(states).cpu().data.numpy()
        self.target_actor.train()
        return np.clip(actions, -1, 1)
class Agent():
    """Interacts with and learns from the environment."""

    def __init__(self, state_size, action_size, memory, seed=None):
        """Initialize an Agent object.

        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            memory: replay buffer supplied by the caller; must provide ``add``,
                ``sample`` and ``len()``
            seed (int or None): random seed passed to the networks
        """
        self.state_size = state_size
        self.action_size = action_size
        # Always define the attribute (None when no seed was given) so reading
        # ``self.seed`` never raises AttributeError.
        self.seed = seed

        # create the local and target actor networks
        self.actor_local = ActorNetwork(state_size, action_size, seed).to(device)
        self.actor_target = ActorNetwork(state_size, action_size, seed).to(device)

        # create the local and target critic networks
        self.critic_local = CriticNetwork(state_size, action_size, seed).to(device)
        self.critic_target = CriticNetwork(state_size, action_size, seed).to(device)

        # optimizers for local actor and critic
        self.actor_optimizer = optim.Adam(self.actor_local.parameters(), lr=LR)
        self.critic_optimizer = optim.Adam(self.critic_local.parameters(), lr=LR, weight_decay=0.0)

        # Huber (smooth L1) loss for the critic instead of MSE
        self.critic_loss_function = nn.SmoothL1Loss()

        # copy the local networks weights to the target network
        self.copy_weights_from_local_to_target()

        # Replay memory
        self.memory = memory

        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0

        # init the noise class to sample from
        self.noise = GaussianNoise(self.action_size)

    def step(self, state, action, reward, next_state, done):
        """Store an experience; every UPDATE_EVERY calls, learn LEARN_TIMES times, then soft-update."""
        self.memory.add(state, action, reward, next_state, done)

        self.t_step = (self.t_step + 1) % UPDATE_EVERY
        if self.t_step == 0 and len(self.memory) > BATCH_SIZE:
            for _ in range(LEARN_TIMES):
                experiences = self.memory.sample()
                self.learn(experiences, GAMMA)
            self.soft_update_all()

    def copy_weights_from_local_to_target(self):
        """Copy local actor/critic weights into their target networks.

        Ensures the local and target networks start with the same weights, or
        propagates weights loaded into the local networks.
        """
        for target_param, param in zip(self.actor_target.parameters(), self.actor_local.parameters()):
            target_param.data.copy_(param.data)
        for target_param, param in zip(self.critic_target.parameters(), self.critic_local.parameters()):
            target_param.data.copy_(param.data)

    def act(self, state, add_noise=False):
        """Return the actor's action values for ``state`` as a NumPy array.

        Params
        ======
            state (array_like): current state; a batch dimension of size 1 is
                added before the forward pass
            add_noise (bool): if true, add a sample from ``self.noise`` for
                exploration (the result is not clipped)
        """
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)

        # get predicted actions for current state from actor network
        self.actor_local.eval()
        with torch.no_grad():
            action_values = self.actor_local(state)
        self.actor_local.train()

        action_values = action_values.cpu().data.numpy()

        if add_noise:
            action_values += self.noise.sample()

        return action_values

    def learn(self, experiences, gamma):
        """Update actor and critic from one batch of experience tuples.

        Target networks are not updated here; see ``soft_update_all``.

        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences

        # compute the loss for the actor network per the DDPG algorithm
        actor_local_predicted_actions = self.actor_local(states)
        policy_loss = -self.critic_local(states, actor_local_predicted_actions).mean()

        # compute the loss for the critic network per the DDPG algorithm
        predicted_Q_vals = self.critic_local(states, actions)
        # The TD target is a fixed label: build it without autograd so backward()
        # does not populate gradients on the target networks.
        with torch.no_grad():
            predicted_actions = self.actor_target(next_states)
            Q_next = self.critic_target(next_states, predicted_actions)
            Q_targets = rewards + (gamma * Q_next * (1 - dones))

        critic_loss = self.critic_loss_function(predicted_Q_vals, Q_targets)

        # update the networks; critic_optimizer.zero_grad() below also clears the
        # critic gradients that policy_loss.backward() deposits
        self.actor_optimizer.zero_grad()
        policy_loss.backward()
        self.actor_optimizer.step()

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic_local.parameters(), 1)
        self.critic_optimizer.step()

    def soft_update_all(self):
        """Soft-update both target networks toward their local networks by TAU."""
        self.soft_update(self.critic_local, self.critic_target, TAU)
        self.soft_update(self.actor_local, self.actor_target, TAU)

    def soft_update(self, local_model, target_model, tau):
        """Soft update model parameters.

        θ_target = τ*θ_local + (1 - τ)*θ_target

        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
            tau (float): interpolation parameter (fraction taken from the local model)
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)
Example #9
0
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
from torch.autograd import Variable

import gym
import numpy as np
import matplotlib.pyplot as plt

from model import ActorNetwork, CriticNetwork

env = gym.make('CartPole-v0')
# Derive network sizes from the environment instead of hard-coding
# CartPole's 4-dimensional observation and 2 discrete actions.
N_STATES = env.observation_space.shape[0]
N_ACTIONS = env.action_space.n

actor = ActorNetwork(N_STATES, N_ACTIONS)
critic = CriticNetwork(N_STATES)
actor_optimizer = optim.Adam(actor.parameters(), lr=1e-4)
critic_optimizer = optim.Adam(critic.parameters(), lr=8e-4)

GAMMA = 0.99  # discount factor
N_EPISODES = 20000
LOG_STEPS = 100
SAVE_STEPS = 100


def select_action(S):
    '''
    Select an action based on the current state.

    args:
        S: current state
    returns:
        action to take, log probability of the chosen action
    '''
    # NOTE(review): no implementation follows the docstring (the snippet looks
    # truncated), so as written this returns None rather than the documented
    # (action, log_prob) pair.
Example #10
0
class Learner(object):
    """V-trace learner that consumes trajectory batches from a queue.

    The environment named by ``opt.env`` is created (and seeded) only to read
    the state size and number of discrete actions. The actor and critic live
    in shared memory, each with its own Adam optimizer (``opt.lr``).
    """

    def __init__(self, opt, q_batch):
        self.opt = opt
        self.q_batch = q_batch

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.env = gym.make(self.opt.env)
        self.env.seed(self.opt.seed)
        self.n_state = self.env.observation_space.shape[0]
        self.n_act = self.env.action_space.n

        self.actor = ActorNetwork(self.n_state, self.n_act).to(self.device)
        self.critic = CriticNetwork(self.n_state).to(self.device)
        self.actor.share_memory()
        self.critic.share_memory()
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=opt.lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=opt.lr)

    def learning(self):
        """Loop forever: one actor and one critic update per queued batch.

        Each queue item is ``(states, actions, rewards, dones, action_log_probs)``.
        NOTE(review): tensors are assumed to be (batch_size, n_step, ...) and on
        ``self.device``, based on the ``view`` calls and arithmetic below; confirm
        against the producer.
        """
        torch.manual_seed(self.opt.seed)
        batch_size, n_step = self.opt.batch_size, self.opt.n_step
        # Truncation levels for the importance weights. They are built on the
        # learner device so torch.min() below does not mix CPU and CUDA tensors.
        coef_hat = torch.full((batch_size, n_step), float(self.opt.coef_hat),
                              dtype=torch.float32, device=self.device)
        rho_hat = torch.full((batch_size, n_step), float(self.opt.rho_hat),
                             dtype=torch.float32, device=self.device)
        while True:
            # batch-trace
            states, actions, rewards, dones, action_log_probs = self.q_batch.get(block=True)

            logit_log_probs = self.actor(states)
            V = self.critic(states).view(batch_size, n_step) * (1 - dones)

            action_probs = torch.exp(action_log_probs)
            logit_probs = torch.exp(logit_log_probs)

            # NOTE(review): the ratio is multiplied over the last axis. If that
            # axis is the action dimension, V-trace usually uses pi(a|x)/mu(a|x)
            # for the taken action only — confirm.
            is_rate = torch.prod(logit_probs / (action_probs + 1e-6), dim=-1).detach()
            coef = torch.min(coef_hat, is_rate) * (1 - dones)
            rho = torch.min(rho_hat, is_rate) * (1 - dones)

            # V-trace, computed backwards in time:
            # v_s = V_s + rho_s*(r_s + gamma*V_{s+1} - V_s) + gamma*c_s*(v_{s+1} - V_{s+1})
            target_V = V.detach()
            v_trace = torch.zeros((batch_size, n_step), device=self.device)
            # Bootstrap the recursion with the critic's value at the last step.
            # Leaving it at 0 would subtract gamma*c*V_last from the earlier
            # targets and regress V at the last step toward 0.
            v_trace[:, -1] = target_V[:, -1]
            for rev_step in reversed(range(states.size(1) - 1)):
                v_trace[:, rev_step] = target_V[:, rev_step] \
                                       + rho[:, rev_step] * (rewards[:, rev_step] + self.opt.gamma*target_V[:, rev_step+1] - target_V[:, rev_step]) \
                                       + self.opt.gamma * coef[:, rev_step] * (v_trace[:, rev_step+1] - target_V[:, rev_step+1])

            # actor loss
            onehot_actions = torch.FloatTensor(
                idx2onehot(actions.cpu().numpy(), self.opt.batch_size, self.n_act)).to(self.device)
            logit_log_probs = torch.sum(logit_log_probs * onehot_actions, dim=-1)
            # NOTE(review): the IMPALA policy-gradient advantage is
            # rho_s*(r_s + gamma*v_{s+1} - V_s); this uses v_s at the same index
            # and no rho — confirm intended.
            advantages = rewards + self.opt.gamma * v_trace - V
            pg_loss = -torch.sum(logit_log_probs * advantages.detach())
            actor_loss = pg_loss

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # critic: regress V toward the (detached) V-trace targets
            critic_loss = torch.mean((v_trace.detach() - V)**2)

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()