Example #1
    def __init__(self, state_dim, action_dim, max_action, args):

        # Mu stuff
        self.mu = Actor(state_dim, action_dim, max_action, args)
        self.mu_t = Actor(state_dim, action_dim, max_action, args)
        self.mu_t.load_state_dict(self.mu.state_dict())

        # Sigma stuff
        self.log_sigma = FloatTensor(
            np.log(args.sigma_init) * np.ones(self.mu.get_size()))
        self.log_sigma_t = FloatTensor(
            np.log(args.sigma_init) * np.ones(self.mu.get_size()))

        # Optimizer
        self.opt = torch.optim.Adam(self.mu.parameters(), lr=args.actor_lr)
        self.opt.add_param_group({"params": self.log_sigma})

        # Critic stuff
        self.critic = Critic(state_dim, action_dim, max_action, args)
        self.critic_t = Critic(state_dim, action_dim, max_action, args)
        self.critic_t.load_state_dict(self.critic.state_dict())
        self.critic_opt = torch.optim.Adam(self.critic.parameters(),
                                           lr=args.critic_lr)

        # Env stuff
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action

        # Hyperparams
        self.tau = args.tau
        self.n_steps = args.n_steps
        self.discount = args.discount
        self.pop_size = args.pop_size
        self.batch_size = args.batch_size
        self.noise_clip = args.noise_clip
        self.policy_freq = args.policy_freq
        self.policy_noise = args.policy_noise
        self.reward_scale = args.reward_scale
        self.n_actor_params = self.mu.get_size()
        self.weights = FloatTensor(
            [self.discount**i for i in range(self.n_steps)])

        # cuda
        if USE_CUDA:
            self.mu.cuda()
            self.mu_t.cuda()
            # .cuda() returns a copy for plain tensors, so reassign
            self.log_sigma = self.log_sigma.cuda()
            self.log_sigma_t = self.log_sigma_t.cuda()
            self.critic.cuda()
            self.critic_t.cuda()
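
The `weights` tensor built above holds the per-step discount factors that the training loops later use to collapse an n-step reward matrix into a single discounted return per transition. A minimal, self-contained sketch of that reduction, assuming rewards arrive as a `(batch, n_steps)` tensor as they do from the `memory.sample` calls below:

import torch

discount, n_steps = 0.99, 3
weights = torch.tensor([discount ** i for i in range(n_steps)])  # [1.0, 0.99, 0.9801]

# Two transitions, each carrying n_steps consecutive rewards
rewards = torch.tensor([[1.0, 1.0, 1.0],
                        [0.5, 0.0, 2.0]])

# Discount each step, then sum along the step axis -> one n-step return per row
returns = (rewards * weights).sum(dim=1, keepdim=True)
print(returns)  # tensor([[2.9701], [2.4602]])
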
Example #2
    def __init__(self, state_dim, action_dim, max_action, args):

        # Actor stuff
        self.actor = GaussianActor(state_dim, action_dim, max_action, args)
        self.actor_t = GaussianActor(state_dim, action_dim, max_action, args)
        self.actor_t.load_state_dict(self.actor.state_dict())
        self.actor_opt = torch.optim.Adam(self.actor.parameters(),
                                          lr=args.actor_lr)

        # Critic stuff
        self.critic = Critic(state_dim, action_dim, max_action, args)
        self.critic_t = Critic(state_dim, action_dim, max_action, args)
        self.critic_t.load_state_dict(self.critic.state_dict())
        self.critic_opt = torch.optim.Adam(self.critic.parameters(),
                                           lr=args.critic_lr)

        # Value stuff
        self.value = Value(state_dim, action_dim, max_action, args)
        self.value_t = Value(state_dim, action_dim, max_action, args)
        self.value_t.load_state_dict(self.value.state_dict())
        self.value_opt = torch.optim.Adam(self.value.parameters(),
                                          lr=args.critic_lr)

        # Env stuff
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action

        # Hyperparams
        self.tau = args.tau
        self.n_steps = args.n_steps
        self.discount = args.discount
        self.batch_size = args.batch_size
        self.noise_clip = args.noise_clip
        self.policy_freq = args.policy_freq
        self.policy_noise = args.policy_noise
        self.reward_scale = args.reward_scale
        self.weights = FloatTensor(
            [self.discount**i for i in range(self.n_steps)])

        # cuda
        if args.use_cuda:
            self.actor.cuda()
            self.actor_t.cuda()
            self.critic.cuda()
            self.critic_t.cuda()
            self.value.cuda()
            self.value_t.cuda()
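
Both of the constructors above build a `GaussianActor` whose forward pass is later unpacked as `(action, mu, sigma)`. The real network is not shown in these examples, so the following is only a plausible stand-in to make that interface concrete; the hidden width, the log-sigma clamp range, and the tanh squashing are assumptions, not the original implementation:

import torch
import torch.nn as nn


class GaussianActorSketch(nn.Module):
    """Hypothetical stand-in for GaussianActor: forward returns (action, mu, sigma)."""

    def __init__(self, state_dim, action_dim, max_action, hidden=256):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(state_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, hidden), nn.ReLU())
        self.mu_head = nn.Linear(hidden, action_dim)
        self.log_sigma_head = nn.Linear(hidden, action_dim)
        self.max_action = max_action

    def forward(self, state):
        h = self.net(state)
        mu = self.mu_head(h)
        sigma = self.log_sigma_head(h).clamp(-20, 2).exp()
        # Reparameterized sample, squashed into the valid action range
        action = self.max_action * torch.tanh(mu + sigma * torch.randn_like(sigma))
        return action, mu, sigma
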
Example #3
class Virel(object):
    """
    VIREL-inspired Actor-Critic Algorithm: https://arxiv.org/abs/1811.01132
    """
    def __init__(self, state_dim, action_dim, max_action, args):

        # Actor stuff
        self.actor = GaussianActor(state_dim, action_dim, max_action, args)
        self.actor_t = GaussianActor(state_dim, action_dim, max_action, args)
        self.actor_t.load_state_dict(self.actor.state_dict())
        self.actor_opt = torch.optim.Adam(self.actor.parameters(),
                                          lr=args.actor_lr)

        # Critic stuff
        self.critic = Critic(state_dim, action_dim, max_action, args)
        self.critic_t = Critic(state_dim, action_dim, max_action, args)
        self.critic_t.load_state_dict(self.critic.state_dict())
        self.critic_opt = torch.optim.Adam(self.critic.parameters(),
                                           lr=args.critic_lr)

        # Value stuff
        self.value = Value(state_dim, action_dim, max_action, args)
        self.value_t = Value(state_dim, action_dim, max_action, args)
        self.value_t.load_state_dict(self.value.state_dict())
        self.value_opt = torch.optim.Adam(self.value.parameters(),
                                          lr=args.critic_lr)

        # Env stuff
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action

        # Hyperparams
        self.tau = args.tau
        self.n_steps = args.n_steps
        self.discount = args.discount
        self.batch_size = args.batch_size
        self.noise_clip = args.noise_clip
        self.policy_freq = args.policy_freq
        self.policy_noise = args.policy_noise
        self.reward_scale = args.reward_scale
        self.weights = FloatTensor(
            [self.discount**i for i in range(self.n_steps)])

        # cuda
        if args.use_cuda:
            self.actor.cuda()
            self.actor_t.cuda()
            self.critic.cuda()
            self.critic_t.cuda()
            self.value.cuda()
            self.value_t.cuda()

    def action(self, state):
        """
        Returns action given state
        """
        state = FloatTensor(state.reshape(1, -1))
        action, mu, sigma = self.actor(state)
        # print("mu", mu)
        # print("sigma", sigma ** 2)
        return action.cpu().data.numpy().flatten()

    def train(self, memory, n_iter):
        """
        Trains the model for n_iter steps
        """
        for it in range(n_iter):

            # Sample replay buffer
            states, actions, n_states, rewards, steps, dones, stops = memory.sample(
                self.batch_size)
            rewards = self.reward_scale * rewards * self.weights
            rewards = rewards.sum(dim=1, keepdim=True)

            # Select action according to policy
            n_actions = self.actor_t(n_states)[0]

            # Q target = reward + discount * min_i(Qi(next_state, pi(next_state)))
            with torch.no_grad():
                target_q1, target_q2 = self.critic_t(n_states, n_actions)
                target_q = torch.min(target_q1, target_q2)
                target_q = target_q * self.discount**(steps + 1)
                target_q = rewards + (1 - stops) * target_q

            # Get current Q estimates
            current_q1, current_q2 = self.critic(states, actions)

            # Compute critic loss
            critic_loss = nn.MSELoss()(current_q1, target_q) + nn.MSELoss()(
                current_q2, target_q)

            # Optimize the critic // M Step
            self.critic_opt.zero_grad()
            critic_loss.backward()
            self.critic_opt.step()

            # Delayed policy updates
            if it % self.policy_freq == 0:

                # actions, mus, sigmas
                actions, mus, sigmas = self.actor(states)

                # Compute actor loss with an entropy bonus
                actor_loss = -self.critic(states, actions)[0]
                # Fixed entropy weight of 5 (an alternative would be critic_loss.detach())
                actor_loss -= 5 * torch.log(sigmas**2).mean(dim=1, keepdim=True)
                actor_loss = actor_loss.mean()

                # Optimize the actor // E Steps
                self.actor_opt.zero_grad()
                actor_loss.backward()
                self.actor_opt.step()

                # Update the frozen actor models
                for param, target_param in zip(self.actor.parameters(),
                                               self.actor_t.parameters()):
                    target_param.data.copy_(self.tau * param.data +
                                            (1 - self.tau) * target_param.data)

            # Update the frozen critic models
            for param, target_param in zip(self.critic.parameters(),
                                           self.critic_t.parameters()):
                target_param.data.copy_(self.tau * param.data +
                                        (1 - self.tau) * target_param.data)

            # Update the frozen value models
            for param, target_param in zip(self.value.parameters(),
                                           self.value_t.parameters()):
                target_param.data.copy_(self.tau * param.data +
                                        (1 - self.tau) * target_param.data)

    def save(self, directory):
        """
        Save the model in given folder
        """
        self.actor.save_model(directory, "actor")
        self.critic.save_model(directory, "critic")

    def load(self, directory):
        """
        Load model from folder
        """
        self.actor.load_model(directory, "actor")
        self.critic.load_model(directory, "critic")
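
Every training loop in these examples refreshes its frozen target networks with the same Polyak averaging written out inline. A small helper capturing that update (the name `soft_update` is mine; the classes above keep the explicit loops):

def soft_update(source, target, tau):
    """Polyak averaging: target <- tau * source + (1 - tau) * target."""
    for param, target_param in zip(source.parameters(), target.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

# Usage inside the training loops above, e.g.:
# soft_update(self.critic, self.critic_t, self.tau)
# soft_update(self.actor, self.actor_t, self.tau)
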
Example #4
class D2TD3(object):
    """
    Double-Smoothed Twin Delayed Deep Deterministic Policy Gradient Algorithm
    """
    def __init__(self, state_dim, action_dim, max_action, args):

        # Mu stuff
        self.mu = Actor(state_dim, action_dim, max_action, args)
        self.mu_t = Actor(state_dim, action_dim, max_action, args)
        self.mu_t.load_state_dict(self.mu.state_dict())

        # Sigma stuff
        self.log_sigma = FloatTensor(
            np.log(args.sigma_init) * np.ones(self.mu.get_size()))
        self.log_sigma_t = FloatTensor(
            np.log(args.sigma_init) * np.ones(self.mu.get_size()))

        # Optimizer
        self.opt = torch.optim.Adam(self.mu.parameters(), lr=args.actor_lr)
        self.opt.add_param_group({"params": self.log_sigma})

        # Critic stuff
        self.critic = Critic(state_dim, action_dim, max_action, args)
        self.critic_t = Critic(state_dim, action_dim, max_action, args)
        self.critic_t.load_state_dict(self.critic.state_dict())
        self.critic_opt = torch.optim.Adam(self.critic.parameters(),
                                           lr=args.critic_lr)

        # Env stuff
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action

        # Hyperparams
        self.tau = args.tau
        self.n_steps = args.n_steps
        self.discount = args.discount
        self.pop_size = args.pop_size
        self.batch_size = args.batch_size
        self.noise_clip = args.noise_clip
        self.policy_freq = args.policy_freq
        self.policy_noise = args.policy_noise
        self.reward_scale = args.reward_scale
        self.n_actor_params = self.mu.get_size()
        self.weights = FloatTensor(
            [self.discount**i for i in range(self.n_steps)])

        # cuda
        if USE_CUDA:
            self.mu.cuda()
            self.mu_t.cuda()
            # .cuda() returns a copy for plain tensors, so reassign
            self.log_sigma = self.log_sigma.cuda()
            self.log_sigma_t = self.log_sigma_t.cuda()
            self.critic.cuda()
            self.critic_t.cuda()

    def train(self, memory, n_iter):
        """
        Trains the model for n_iter steps
        """

        for it in range(n_iter):

            # Sample replay buffer
            states, actions, n_states, rewards, steps, dones, stops = memory.sample(
                self.batch_size)
            rewards = self.reward_scale * rewards * self.weights
            rewards = rewards.sum(dim=1, keepdim=True)

            # Select action according to the target policy
            # (parameter-noise variant kept below for reference)
            # mu_t = self.mu_t.get_params()
            # log_sigma_t = self.log_sigma_t.data.cpu().numpy()
            # noise = np.random.randn(self.n_actor_params)
            # pi_t = mu_t + noise * np.exp(log_sigma_t)

            # self.mu_t.set_params(pi_t)
            n_actions = self.mu_t(n_states)
            # self.mu.set_params(mu_t)

            # Q target = reward + discount * min_i(Qi(next_state, pi(next_state)))
            with torch.no_grad():
                target_Q1, target_Q2 = self.critic_t(n_states, n_actions)
                target_Q = torch.min(target_Q1, target_Q2)
                target_Q = target_Q * self.discount**(steps + 1)
                target_Q = rewards + (1 - stops) * target_Q

            # Get current Q estimates
            current_Q1, current_Q2 = self.critic(states, actions)

            # Compute critic loss
            critic_loss = nn.MSELoss()(current_Q1, target_Q) + \
                nn.MSELoss()(current_Q2, target_Q)

            # Optimize the critic
            self.critic_opt.zero_grad()
            critic_loss.backward()
            self.critic_opt.step()

            # Delayed policy updates
            if it % self.policy_freq == 0:

                # Creating random policy
                mu = self.mu.get_params()
                log_sigma = self.log_sigma.data.cpu().numpy()
                noise = np.random.randn(self.n_actor_params)
                pi = mu + noise * np.exp(log_sigma)

                # Computing loss
                self.mu.set_params(pi)
                pi_loss = -self.critic(states, self.mu(states))[0].mean()

                # Computing gradient wrt the noisy policy
                self.opt.zero_grad()  # clear gradients left over from the previous policy update
                pi_loss.backward()
                pi_grad = self.mu.get_grads()

                # Setting gradients on mu and log_sigma, then stepping from the mean policy
                self.mu.set_params(mu)
                self.mu.set_grads(pi_grad)
                self.log_sigma.grad = FloatTensor(pi_grad * noise *
                                                  np.exp(log_sigma))
                self.opt.step()

                # Update the frozen mu
                for param, target_param in zip(self.mu.parameters(),
                                               self.mu_t.parameters()):
                    target_param.data.copy_(self.tau * param.data +
                                            (1 - self.tau) * target_param.data)

                # Update the frozen sigma
                self.log_sigma_t = self.tau * self.log_sigma + \
                    (1 - self.tau) * self.log_sigma_t

            # Update the frozen critic models
            for param, target_param in zip(self.critic.parameters(),
                                           self.critic_t.parameters()):
                target_param.data.copy_(self.tau * param.data +
                                        (1 - self.tau) * target_param.data)

    def save(self, directory):
        """
        Save the model in given folder
        """
        self.mu.save_model(directory, "actor")
        self.critic.save_model(directory, "critic")

    def load(self, directory):
        """
        Load model from folder
        """
        self.mu.load_model(directory, "actor")
        self.critic.load_model(directory, "critic")
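
The gradient that D2TD3 writes into `self.log_sigma.grad` comes from the reparameterization `pi = mu + noise * exp(log_sigma)`: by the chain rule, `dL/d log_sigma = (dL/d pi) * noise * exp(log_sigma)`, which is exactly `pi_grad * noise * np.exp(log_sigma)` above. A tiny autograd check of that identity on toy values (the variable names here are illustrative only):

import torch

mu = torch.tensor([0.1, -0.3], requires_grad=True)
log_sigma = torch.tensor([-1.0, 0.5], requires_grad=True)
noise = torch.tensor([0.7, -1.2])

pi = mu + noise * log_sigma.exp()
loss = (pi ** 2).sum()          # any differentiable loss of the sampled parameters
loss.backward()

# Manual gradient, as used above: dL/dpi * noise * exp(log_sigma)
manual = (2 * pi.detach()) * noise * log_sigma.exp().detach()
print(torch.allclose(log_sigma.grad, manual))  # True
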
Example #5
class STD3(object):
    """
    Smoothed Twin Delayed Deep Deterministic Policy Gradient Algorithm
    """
    def __init__(self, state_dim, action_dim, max_action, args):

        # Actor stuff
        self.actor = Actor(state_dim, action_dim, max_action, args)
        self.actor_t = Actor(state_dim, action_dim, max_action, args)
        self.actor_t.load_state_dict(self.actor.state_dict())
        self.actor_opt = torch.optim.Adam(self.actor.parameters(),
                                          lr=args.actor_lr)

        # Critic stuff
        self.critic = Critic(state_dim, action_dim, max_action, args)
        self.critic_t = Critic(state_dim, action_dim, max_action, args)
        self.critic_t.load_state_dict(self.critic.state_dict())
        self.critic_opt = torch.optim.Adam(self.critic.parameters(),
                                           lr=args.critic_lr)

        # Env stuff
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action

        # Hyperparams
        self.tau = args.tau
        self.n_steps = args.n_steps
        self.discount = args.discount
        self.batch_size = args.batch_size
        self.noise_clip = args.noise_clip
        self.policy_freq = args.policy_freq
        self.policy_noise = args.policy_noise
        self.reward_scale = args.reward_scale
        self.weights = FloatTensor(
            [self.discount**i for i in range(self.n_steps)])

        # cuda
        if args.use_cuda:
            self.actor.cuda()
            self.actor_t.cuda()
            self.critic.cuda()
            self.critic_t.cuda()

    def action(self, state):
        """
        Returns action given state
        """
        state = FloatTensor(state.reshape(1, -1))
        return self.actor(state).cpu().data.numpy().flatten()

    def train(self, memory, n_iter):
        """
        Trains the model for n_iter steps
        """

        for it in range(n_iter):

            # Sample replay buffer
            states, actions, n_states, rewards, steps, dones, stops = memory.sample(
                self.batch_size)
            print("before:", rewards)
            rewards = self.reward_scale * rewards * self.weights
            rewards = rewards.sum(dim=1, keepdim=True)
            print("after:", rewards)

            # Select action according to target policy
            # (clipped noise computed below but currently not added)
            noise = np.clip(
                np.random.normal(0,
                                 self.policy_noise,
                                 size=(self.batch_size, self.action_dim)),
                -self.noise_clip, self.noise_clip)
            n_actions = self.actor_t(n_states)  # + FloatTensor(noise)
            n_actions = n_actions.clamp(-self.max_action, self.max_action)

            # Q target = reward + discount * min_i(Qi(next_state, pi(next_state)))
            with torch.no_grad():
                target_Q1, target_Q2 = self.critic_t(n_states, n_actions)
                target_Q = torch.min(target_Q1, target_Q2)
                target_Q = target_Q * self.discount**(steps + 1)
                target_Q = rewards + (1 - stops) * target_Q

            # Get current Q estimates
            current_Q1, current_Q2 = self.critic(states, actions)

            # Compute critic loss
            critic_loss = nn.MSELoss()(current_Q1, target_Q) + \
                nn.MSELoss()(current_Q2, target_Q)

            # Optimize the critic
            self.critic_opt.zero_grad()
            critic_loss.backward()
            self.critic_opt.step()

            # Delayed policy updates
            if it % self.policy_freq == 0:

                # Compute actor loss
                # noise = np.clip(np.random.normal(0, self.policy_noise, size=(
                #     self.batch_size, self.action_dim)), -self.noise_clip, self.noise_clip)
                # n_actions = self.actor(states) + FloatTensor(noise)
                # n_actions = n_actions.clamp(-self.max_action, self.max_action)
                # actor_loss = -self.critic(states, n_actions)[0].mean()

                actor_params = self.actor.get_params()
                grads = np.zeros(self.actor.get_size())

                for _ in range(5):

                    noise = np.random.normal(0,
                                             self.policy_noise,
                                             size=(self.actor.get_size()))
                    self.actor.set_params(actor_params +
                                          noise * self.policy_noise)

                    n_actions = self.actor(states)  # + FloatTensor(noise)
                    n_actions = n_actions.clamp(-self.max_action,
                                                self.max_action)

                    self.actor_opt.zero_grad()
                    actor_loss = -self.critic(states, n_actions)[0].mean()
                    actor_loss.backward()

                    # Alternative: weight each sample by its Gaussian density,
                    # np.exp(-noise**2 / (2 * self.policy_noise**2))
                    # / (np.sqrt(2 * np.pi) * self.policy_noise)
                    grads += self.actor.get_grads()

                self.actor_opt.zero_grad()
                self.actor.set_params(actor_params)
                self.actor.set_grads(grads / 5)
                self.actor_opt.step()

                # Optimize the actor
                # self.actor_opt.zero_grad()
                # actor_loss.backward()
                # self.actor_opt.step()

                # Update the frozen actor models
                for param, target_param in zip(self.actor.parameters(),
                                               self.actor_t.parameters()):
                    target_param.data.copy_(self.tau * param.data +
                                            (1 - self.tau) * target_param.data)

            # Update the frozen critic models
            for param, target_param in zip(self.critic.parameters(),
                                           self.critic_t.parameters()):
                target_param.data.copy_(self.tau * param.data +
                                        (1 - self.tau) * target_param.data)

    def save(self, directory):
        """
        Save the model in given folder
        """
        self.actor.save_model(directory, "actor")
        self.critic.save_model(directory, "critic")

    def load(self, directory):
        """
        Load model from folder
        """
        self.actor.load_model(directory, "actor")
        self.critic.load_model(directory, "critic")
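
STD3 builds a clipped Gaussian noise array for target-policy smoothing but currently applies the target actor without it (the addition is commented out where `n_actions` is computed). For reference, the usual TD3-style smoothing adds the clipped noise before clamping to the action range; a minimal sketch, with the helper name being mine:

import torch

def smoothed_target_actions(actor_t, n_states, policy_noise, noise_clip, max_action):
    """TD3 target-policy smoothing: add clipped Gaussian noise, then clamp."""
    with torch.no_grad():
        actions = actor_t(n_states)
        noise = (torch.randn_like(actions) * policy_noise).clamp(-noise_clip, noise_clip)
        return (actions + noise).clamp(-max_action, max_action)

# e.g. n_actions = smoothed_target_actions(self.actor_t, n_states,
#                                          self.policy_noise, self.noise_clip,
#                                          self.max_action)
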
Example #6
class NASTD3(object):
    """
    Twin Delayed Deep Deterministic Policy Gradient Algorithm with n-step return
    """
    def __init__(self, state_dim, action_dim, max_action, args):

        # Actor stuff
        self.actor = NASActor(state_dim, action_dim, max_action, args)
        self.actor_t = NASActor(state_dim, action_dim, max_action, args)
        self.actor_t.load_state_dict(self.actor.state_dict())
        self.actor_opt = torch.optim.Adam(self.actor.parameters(),
                                          lr=args.actor_lr)

        # Critic stuff
        self.critic = Critic(state_dim, action_dim, max_action, args)
        self.critic_t = Critic(state_dim, action_dim, max_action, args)
        self.critic_t.load_state_dict(self.critic.state_dict())
        self.critic_opt = torch.optim.Adam(self.critic.parameters(),
                                           lr=args.critic_lr)

        # Env stuff
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action

        # Hyperparams
        self.tau = args.tau
        self.n_steps = args.n_steps
        self.discount = args.discount
        self.batch_size = args.batch_size
        self.noise_clip = args.noise_clip
        self.policy_freq = args.policy_freq
        self.policy_noise = args.policy_noise
        self.reward_scale = args.reward_scale
        self.weights = FloatTensor(
            [self.discount**i for i in range(self.n_steps)])

        # cuda
        if args.use_cuda:
            self.actor.cuda()
            self.actor_t.cuda()
            self.critic.cuda()
            self.critic_t.cuda()

    def action(self, state):
        """
        Returns action given state
        """
        state = FloatTensor(state.reshape(1, -1))
        return self.actor(state).cpu().data.numpy().flatten()

    def train(self, memory, n_iter):
        """
        Trains the model for n_iter steps
        """
        critic_losses = []
        actor_losses = []
        for it in tqdm(range(n_iter)):

            # Sample replay buffer
            states, actions, n_states, rewards, steps, dones, stops = memory.sample(
                self.batch_size)
            rewards = self.reward_scale * rewards * self.weights
            rewards = rewards.sum(dim=1, keepdim=True)

            # Select action according to policy
            n_actions = self.actor_t(n_states)

            # Q target = reward + discount * min_i(Qi(next_state, pi(next_state)))
            with torch.no_grad():
                target_q1, target_q2 = self.critic_t(n_states, n_actions)
                target_q = torch.min(target_q1, target_q2)
                target_q = rewards + (1 - stops) * target_q * self.discount**(
                    steps + 1)

            # Get current Q estimates
            current_q1, current_q2 = self.critic(states, actions)

            # Compute critic loss
            critic_loss = nn.MSELoss()(current_q1, target_q) + \
                nn.MSELoss()(current_q2, target_q)
            critic_losses.append(critic_loss.data.cpu().numpy())

            # Optimize the critic
            self.critic_opt.zero_grad()
            critic_loss.backward()
            self.critic_opt.step()

            # Delayed policy updates
            if it % self.policy_freq == 0:

                # Compute actor loss
                actor_loss = -self.critic(states, self.actor(states))[0].mean()
                actor_losses.append(actor_loss.data.cpu().numpy())

                # Optimize the actor
                self.actor_opt.zero_grad()
                actor_loss.backward()
                self.actor_opt.step()

                # Normalize alphas
                self.actor.normalize_alpha()
                self.actor_t.normalize_alpha()

                # Update the frozen actor models
                for param, target_param in zip(self.actor.parameters(),
                                               self.actor_t.parameters()):
                    target_param.data.copy_(self.tau * param.data +
                                            (1 - self.tau) * target_param.data)

            # Update the frozen critic models
            for param, target_param in zip(self.critic.parameters(),
                                           self.critic_t.parameters()):
                target_param.data.copy_(self.tau * param.data +
                                        (1 - self.tau) * target_param.data)

        return np.mean(critic_losses), np.mean(actor_losses)

    def save(self, directory):
        """
        Save the model in given folder
        """
        self.actor.save_model(directory, "actor")
        self.critic.save_model(directory, "critic")

    def load(self, directory):
        """
        Load model from folder
        """
        self.actor.load_model(directory, "actor")
        self.critic.load_model(directory, "critic")
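
NASTD3.train returns the mean critic and actor losses over the iteration block, which makes it straightforward to log learning curves from a driver loop. A hedged usage sketch; `agent` (a NASTD3 instance) and `memory` are assumed to be constructed elsewhere in the code base:

def run_training(agent, memory, epochs=100, iters_per_epoch=1000):
    """Drive NASTD3.train and log the averaged losses it returns."""
    for epoch in range(epochs):
        critic_loss, actor_loss = agent.train(memory, iters_per_epoch)
        print(f"epoch {epoch}: critic {critic_loss:.3f} | actor {actor_loss:.3f}")
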
Example #7
class MPO(object):
    """
    MPO-inspired Actor-Critic Algorithm: https://arxiv.org/pdf/1806.06920.pdf
    """
    def __init__(self, state_dim, action_dim, max_action, args):

        # Actor stuff
        self.pi = GaussianActor(state_dim, action_dim, max_action, args)
        self.pi_t = GaussianActor(state_dim, action_dim, max_action, args)
        self.pi_t.load_state_dict(self.pi.state_dict())
        self.pi_opt = torch.optim.Adam(self.pi.parameters(), lr=args.actor_lr)

        # Variational policy stuff
        self.q = GaussianActor(state_dim, action_dim, max_action, args)
        self.q_t = GaussianActor(state_dim, action_dim, max_action, args)
        self.q.load_state_dict(self.pi.state_dict())
        self.q_t.load_state_dict(self.pi.state_dict())
        self.q_opt = torch.optim.Adam(self.q.parameters(), lr=args.actor_lr)

        # Critic stuff
        self.critic = Critic(state_dim, action_dim, max_action, args)
        self.critic_t = Critic(state_dim, action_dim, max_action, args)
        self.critic_t.load_state_dict(self.critic.state_dict())
        self.critic_opt = torch.optim.Adam(self.critic.parameters(),
                                           lr=args.critic_lr)

        # Env stuff
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action

        # Hyperparams
        self.tau = args.tau
        self.alpha = args.alpha
        self.n_steps = args.n_steps
        self.discount = args.discount
        self.batch_size = args.batch_size
        self.noise_clip = args.noise_clip
        self.policy_freq = args.policy_freq
        self.policy_noise = args.policy_noise
        self.reward_scale = args.reward_scale
        self.weights = FloatTensor(
            [self.discount**i for i in range(self.n_steps)])

        # cuda
        if args.use_cuda:
            self.pi.cuda()
            self.pi_t.cuda()
            self.q.cuda()
            self.q_t.cuda()
            self.critic.cuda()
            self.critic_t.cuda()

    def action(self, state):
        """
        Returns action given state
        """
        state = FloatTensor(state.reshape(1, -1))
        action, mu, sigma = self.pi(state)
        # print("mu", mu)
        # print("sigma", sigma)
        return action.cpu().data.numpy().flatten()

    def train(self, memory, n_iter):
        """
        Trains the model for n_iter steps
        """
        for it in range(n_iter):

            # Sample replay buffer
            states, actions, n_states, rewards, steps, dones, stops = memory.sample(
                self.batch_size)
            rewards = self.reward_scale * rewards * self.weights
            rewards = rewards.sum(dim=1, keepdim=True)

            # Select action according to policy pi
            n_actions = self.pi(n_states)[0]

            # Q target = reward + discount * min_i(Qi(next_state, pi(next_state)))
            with torch.no_grad():
                target_q1, target_q2 = self.critic_t(n_states, n_actions)
                target_q = torch.min(target_q1, target_q2)
                target_q = target_q * self.discount**(steps + 1)
                target_q = rewards + (1 - stops) * target_q

            # Get current Q estimates
            current_q1, current_q2 = self.critic(states, actions)

            # Compute critic loss
            critic_loss = nn.MSELoss()(current_q1, target_q) + nn.MSELoss()(
                current_q2, target_q)

            # Optimize the critic
            self.critic_opt.zero_grad()
            critic_loss.backward()
            self.critic_opt.step()

            # E-Step

            # actions, mus, sigmas
            pi_a, pi_mus, pi_sigmas = self.pi(states)
            q_a, q_mus, q_sigmas = self.q(states)

            # KL div between pi and q
            kl_div = torch.log(pi_sigmas**2 / q_sigmas**2)
            kl_div += (q_sigmas**4 + (q_mus - pi_mus)**2) / (2 * pi_sigmas**4)
            kl_div = kl_div.mean(dim=1, keepdim=True)

            # q loss
            loss = self.critic(states, q_a)[0] - kl_div
            loss = -loss.mean()

            # SGD
            self.q_opt.zero_grad()
            self.pi_opt.zero_grad()
            loss.backward()
            self.q_opt.step()
            self.pi_opt.step()

            # M-Step

            # actions, mus, sigmas
            # pi_a, pi_mus, pi_sigmas = self.pi(states)
            # q_a, q_mus, q_sigmas = self.q(states)
            # q_a.detach(), q_mus.detach(), q_sigmas.detach()
            #
            # # KL div between pi and q
            # kl_div = torch.log(pi_sigmas ** 2 / q_sigmas ** 2)
            # kl_div += (q_sigmas ** 4 + (q_mus - pi_mus) ** 2) / \
            #     (2 * pi_sigmas ** 4)
            # kl_div = kl_div.mean(dim=1, keepdim=True)
            #
            # # pi_loss
            # pi_loss = kl_div.mean()
            #
            # # SGD
            # self.pi_opt.zero_grad()
            # pi_loss.backward()
            # torch.nn.utils.clip_grad_norm_(self.pi.parameters(), 1)
            # self.pi_opt.step()

            # print(pi_loss.data, loss.data)

            # Update the frozen actor models
            for param, target_param in zip(self.pi.parameters(),
                                           self.pi_t.parameters()):
                target_param.data.copy_(self.tau * param.data +
                                        (1 - self.tau) * target_param.data)

            # Update the frozen critic models
            for param, target_param in zip(self.critic.parameters(),
                                           self.critic_t.parameters()):
                target_param.data.copy_(self.tau * param.data +
                                        (1 - self.tau) * target_param.data)

    def save(self, directory):
        """
        Save the model in given folder
        """
        self.pi.save_model(directory, "actor")
        self.critic.save_model(directory, "critic")

    def load(self, directory):
        """
        Load model from folder
        """
        self.pi.load_model(directory, "actor")
        self.critic.load_model(directory, "critic")
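
MPO's E-step penalizes the variational policy `q` with an inline KL term between two diagonal Gaussians. For reference, the standard closed-form KL between diagonal Gaussians, summed over action dimensions, is sketched below; note that the inline version above squares the variances and drops the constant, so this is a reference point rather than a drop-in replacement:

import torch

def diag_gaussian_kl(mu_q, sigma_q, mu_p, sigma_p):
    """KL( N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2) ), summed over the last dim."""
    var_q, var_p = sigma_q ** 2, sigma_p ** 2
    kl = torch.log(sigma_p / sigma_q) + (var_q + (mu_q - mu_p) ** 2) / (2 * var_p) - 0.5
    return kl.sum(dim=-1, keepdim=True)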