Example #1
class DDPG(Chain):
    def __init__(self):
        super(DDPG, self).__init__(
            actor=Actor(),
            critic=Critic(),
        )
        self.target_actor = deepcopy(self.actor)
        self.target_critic = deepcopy(self.critic)
        disable_train(self.target_actor)
        disable_train(self.target_critic)

        self.noise = OrnsteinUhlenbeckActionNoise(mu=np.zeros(A_DIM))
        self.buffer = ReplayBuffer(BUFFER_SIZE)
        self.time = 0

    def reset(self, s):
        self.prev_s = s
        self.noise.reset()

    def step(self, s, r, done, trainable):
        self.time += 1
        self.buffer.add(self.prev_s, self.prev_a, r, done, s, self.prev_noise)
        self.prev_s = s
        if trainable and self.time % TRAIN_INTERVAL == 0:
            if len(self.buffer) > NUM_WARMUP_STEP:
                return self._update()

    def get_action(self):
        S, = make_batch(self.prev_s)
        a = self.actor(S)[0]  # (A_DIM, )
        noise = self.noise().astype(np.float32)
        self.prev_a = a
        self.prev_noise = noise
        return (a + noise).data.reshape(-1)

    def _update(self):
        S, A, R, D, S2, N = self.buffer.sample_batch(BATCH_SIZE)  # 6 lists of length BATCH_SIZE
        S = np.array(S, dtype=np.float32)  # (BATCH_SIZE, O_DIM)
        S2 = np.array(S2, dtype=np.float32)
        A = F.stack(A)  # (BATCH_SIZE, A_DIM)
        R = np.array(R, dtype=np.float32).reshape(-1, 1)
        D = np.array(D, dtype=np.float32).reshape(-1, 1)  # terminal flags
        N = np.array(N)

        # update critic: TD target from the target networks, masking out terminal states
        A_ = self.target_actor(S2)
        Y = R + GAMMA * self.target_critic(S2, A_.data) * (1.0 - D)
        Q_batch = self.critic(S, (A + N).data)
        critic_loss = F.mean_squared_error(Y.data, Q_batch)
        self.critic.update(critic_loss)

        # update actor
        A = self.actor(S)  # recompute actions with the current actor so the Q gradient flows back into the actor (deterministic policy gradient)
        Q = self.critic(S, A)
        actor_loss = -F.sum(Q) / BATCH_SIZE
        #from chainer import computational_graph as c
        #g = c.build_computational_graph([actor_loss])
        #with open('graph_actorloss.dot', 'w') as o:
        #    o.write(g.dump())
        #exit()
        self.actor.update(actor_loss)

        # update target
        soft_copy_param(self.target_critic, self.critic, TAU)
        soft_copy_param(self.target_actor, self.actor, TAU)

        return actor_loss.data, critic_loss.data
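
The target networks above are blended toward the online networks with soft_copy_param, which is defined elsewhere in the project. Below is a minimal sketch of that helper, assuming chainerrl-style Polyak averaging (theta_target <- tau * theta + (1 - tau) * theta_target); the project's actual helper may differ.

def soft_copy_param(target_link, source_link, tau):
    """Minimal sketch (assumption): Polyak-average source params into the target link."""
    target_params = dict(target_link.namedparams())
    for name, param in source_link.namedparams():
        tgt = target_params[name].data
        # in-place blend so the target link keeps its own arrays
        tgt[:] = tau * param.data + (1.0 - tau) * tgt
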
Example #2
class SAC:
    # TODO:
    # scale action
    # load save
    def __init__(
            self,
            env,
            learning_rate: float = 3e-4,
            tau: float = 0.005,
            buffer_size: int = 1_000_000,
            alpha: Union[float, str] = 'auto',
            net_arch: List[int] = [256, 256],
            batch_size: int = 256,
            num_q_nets: int = 2,
            m_sample: Optional[int] = None,  # None == plain SAC, e.g. 2 == REDQ-style subset minimum
            learning_starts: int = 100,
            gradient_updates: int = 1,
            gamma: float = 0.99,
            mbpo: bool = False,
            dynamics_rollout_len: int = 1,
            rollout_dynamics_starts: int = 5000,
            real_ratio: float = 0.05,
            project_name: str = 'sac',
            experiment_name: Optional[str] = None,
            log: bool = True,
            wandb: bool = True,
            device: Union[th.device, str] = 'auto'):

        self.env = env
        self.observation_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.shape[0]
        self.learning_rate = learning_rate
        self.tau = tau
        self.gamma = gamma
        self.buffer_size = buffer_size
        self.num_q_nets = num_q_nets
        self.m_sample = m_sample
        self.net_arch = net_arch
        self.learning_starts = learning_starts
        self.batch_size = batch_size
        self.gradient_updates = gradient_updates
        if device == 'auto':
            self.device = th.device('cuda' if th.cuda.is_available() else 'cpu')
        else:
            self.device = device
        self.replay_buffer = ReplayBuffer(self.observation_dim,
                                          self.action_dim,
                                          max_size=buffer_size)

        self.q_nets = [
            SoftQNetwork(self.observation_dim + self.action_dim,
                         net_arch=net_arch).to(self.device)
            for _ in range(num_q_nets)
        ]
        self.target_q_nets = [
            SoftQNetwork(self.observation_dim + self.action_dim,
                         net_arch=net_arch).to(self.device)
            for _ in range(num_q_nets)
        ]
        for q_net, target_q_net in zip(self.q_nets, self.target_q_nets):
            target_q_net.load_state_dict(q_net.state_dict())
            for param in target_q_net.parameters():
                param.requires_grad = False

        self.policy = Policy(self.observation_dim,
                             self.action_dim,
                             self.env.action_space,
                             net_arch=net_arch).to(self.device)

        self.target_entropy = -th.prod(th.Tensor(
            self.env.action_space.shape)).item()
        if alpha == 'auto':
            self.log_alpha = th.zeros(1,
                                      requires_grad=True,
                                      device=self.device)
            self.alpha = self.log_alpha.exp().item()
            self.alpha_optim = optim.Adam([self.log_alpha],
                                          lr=self.learning_rate)
        else:
            self.alpha_optim = None
            self.alpha = alpha

        q_net_params = []
        for q_net in self.q_nets:
            q_net_params += list(q_net.parameters())
        self.q_optim = optim.Adam(q_net_params, lr=self.learning_rate)
        self.policy_optim = optim.Adam(list(self.policy.parameters()),
                                       lr=self.learning_rate)

        self.mbpo = mbpo
        if self.mbpo:
            self.dynamics = ProbabilisticEnsemble(
                input_dim=self.observation_dim + self.action_dim,
                output_dim=self.observation_dim + 1,
                device=self.device)
            self.dynamics_buffer = ReplayBuffer(self.observation_dim,
                                                self.action_dim,
                                                max_size=400000)
        self.dynamics_rollout_len = dynamics_rollout_len
        self.rollout_dynamics_starts = rollout_dynamics_starts
        self.real_ratio = real_ratio

        self.experiment_name = experiment_name if experiment_name is not None else f"sac_{int(time.time())}"
        self.log = log
        if self.log:
            self.writer = SummaryWriter(f"runs/{self.experiment_name}")
            if wandb:
                import wandb  # the imported module shadows the boolean `wandb` argument from here on
                wandb.init(project=project_name,
                           sync_tensorboard=True,
                           config=self.get_config(),
                           name=self.experiment_name,
                           monitor_gym=True,
                           save_code=True)
                self.writer = SummaryWriter(f"/tmp/{self.experiment_name}")

    def get_config(self):
        return {
            'env_id': self.env.unwrapped.spec.id,
            'learning_rate': self.learning_rate,
            'num_q_nets': self.num_q_nets,
            'batch_size': self.batch_size,
            'tau': self.tau,
            'gamma': self.gamma,
            'net_arch': self.net_arch,
            'gradient_updates': self.gradient_updates,
            'm_sample': self.m_sample,
            'buffer_size': self.buffer_size,
            'learning_starts': self.learning_starts,
            'mbpo': self.mbpo,
            'dynamics_rollout_len': self.dynamics_rollout_len
        }

    def save(self, save_replay_buffer=True):
        save_dir = 'weights/'
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)

        saved_params = {
            'policy_state_dict': self.policy.state_dict(),
            'policy_optimizer_state_dict': self.policy_optim.state_dict(),
        }
        if self.alpha_optim is not None:  # only present when the temperature is learned
            saved_params['log_alpha'] = self.log_alpha
            saved_params['alpha_optimizer_state_dict'] = self.alpha_optim.state_dict()
        for i, (q_net,
                target_q_net) in enumerate(zip(self.q_nets,
                                               self.target_q_nets)):
            saved_params['q_net_' + str(i) +
                         '_state_dict'] = q_net.state_dict()
            saved_params['target_q_net_' + str(i) +
                         '_state_dict'] = target_q_net.state_dict()
        saved_params['q_nets_optimizer_state_dict'] = self.q_optim.state_dict()

        if save_replay_buffer:
            saved_params['replay_buffer'] = self.replay_buffer

        th.save(saved_params, save_dir + "/" + self.experiment_name + '.tar')

    def load(self, path, load_replay_buffer=True):
        params = th.load(path)
        self.policy.load_state_dict(params['policy_state_dict'])
        self.policy_optim.load_state_dict(
            params['policy_optimizer_state_dict'])
        if self.alpha_optim is not None and 'log_alpha' in params:
            self.log_alpha = params['log_alpha']
            self.alpha_optim.load_state_dict(
                params['alpha_optimizer_state_dict'])
        for i, (q_net,
                target_q_net) in enumerate(zip(self.q_nets,
                                               self.target_q_nets)):
            q_net.load_state_dict(params['q_net_' + str(i) + '_state_dict'])
            target_q_net.load_state_dict(params['target_q_net_' + str(i) +
                                                '_state_dict'])
        self.q_optim.load_state_dict(params['q_nets_optimizer_state_dict'])
        if load_replay_buffer and 'replay_buffer' in params:
            self.replay_buffer = params['replay_buffer']

    def sample_batch_experiences(self):
        if not self.mbpo or self.num_timesteps < self.rollout_dynamics_starts:
            return self.replay_buffer.sample(self.batch_size,
                                             to_tensor=True,
                                             device=self.device)
        else:
            # mix real and model-generated samples according to real_ratio (default 5% real)
            num_real_samples = int(self.batch_size * self.real_ratio)
            s_obs, s_actions, s_rewards, s_next_obs, s_dones = self.replay_buffer.sample(
                num_real_samples, to_tensor=True, device=self.device)
            m_obs, m_actions, m_rewards, m_next_obs, m_dones = self.dynamics_buffer.sample(
                self.batch_size - num_real_samples,
                to_tensor=True,
                device=self.device)
            experience_tuples = (th.cat([s_obs, m_obs], dim=0),
                                 th.cat([s_actions, m_actions], dim=0),
                                 th.cat([s_rewards, m_rewards], dim=0),
                                 th.cat([s_next_obs, m_next_obs], dim=0),
                                 th.cat([s_dones, m_dones], dim=0))
            return experience_tuples

    def rollout_dynamics(self):
        # MBPO Planning
        with th.no_grad():
            # 4 batches of 25000 instead of one of 100000 so we do not allocate all GPU memory
            for _ in range(4):
                obs = self.replay_buffer.sample_obs(25000,
                                                    to_tensor=True,
                                                    device=self.device)
                fake_env = FakeEnv(self.dynamics, self.env.unwrapped.spec.id)
                for plan_step in range(self.dynamics_rollout_len):
                    actions = self.policy(obs, deterministic=False)

                    next_obs_pred, r_pred, dones, info = fake_env.step(
                        obs, actions)
                    obs = obs.detach().cpu().numpy()
                    actions = actions.detach().cpu().numpy()

                    for i in range(len(obs)):
                        self.dynamics_buffer.add(obs[i], actions[i], r_pred[i],
                                                 next_obs_pred[i], dones[i])

                    nonterm_mask = ~dones.squeeze(-1)
                    if nonterm_mask.sum() == 0:
                        break

                    obs = next_obs_pred[nonterm_mask]

    @property
    def dynamics_train_freq(self):
        if self.num_timesteps < 100000:
            return 250
        else:
            return 1000

    def train(self):
        for _ in range(self.gradient_updates):
            s_obs, s_actions, s_rewards, s_next_obs, s_dones = \
                self.sample_batch_experiences()

            with th.no_grad():
                next_actions, log_probs = self.policy.action_log_prob(
                    s_next_obs)
                q_input = th.cat([s_next_obs, next_actions], dim=1)
                if self.m_sample is not None:  # REDQ: random subset of M target critics
                    sampled_targets = np.random.choice(self.target_q_nets,
                                                       self.m_sample,
                                                       replace=False)
                    q_targets = th.cat(
                        [q_target(q_input) for q_target in sampled_targets],
                        dim=1)
                else:
                    q_targets = th.cat(
                        [q_target(q_input) for q_target in self.target_q_nets],
                        dim=1)

                target_q, _ = th.min(q_targets, dim=1, keepdim=True)
                target_q -= self.alpha * log_probs.reshape(-1, 1)
                target_q = s_rewards + (1 - s_dones) * self.gamma * target_q

            sa = th.cat([s_obs, s_actions], dim=1)
            q_values = [q_net(sa) for q_net in self.q_nets]
            critic_loss = (1 / self.num_q_nets) * sum(
                [F.mse_loss(q_value, target_q) for q_value in q_values])

            self.q_optim.zero_grad()
            critic_loss.backward()
            self.q_optim.step()

            # Polyak update
            for q_net, target_q_net in zip(self.q_nets, self.target_q_nets):
                for param, target_param in zip(q_net.parameters(),
                                               target_q_net.parameters()):
                    target_param.data.copy_(self.tau * param.data +
                                            (1 - self.tau) * target_param.data)

        # Policy update
        actions, log_pi = self.policy.action_log_prob(s_obs)
        sa = th.cat([s_obs, actions], dim=1)
        q_values_pi = th.cat([q_net(sa) for q_net in self.q_nets], dim=1)
        if self.m_sample is not None:
            min_q_value_pi = th.mean(q_values_pi, dim=1, keepdim=True)
        else:
            min_q_value_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
        policy_loss = (self.alpha * log_pi - min_q_value_pi).mean()

        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        # Automatic temperature learning
        if self.alpha_optim is not None:
            alpha_loss = (-self.log_alpha *
                          (log_pi.detach() + self.target_entropy)).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
            self.alpha = self.log_alpha.exp().item()

        # Log losses
        if self.log and self.num_timesteps % 100 == 0:
            self.writer.add_scalar("losses/critic_loss", critic_loss.item(),
                                   self.num_timesteps)
            self.writer.add_scalar("losses/policy_loss", policy_loss.item(),
                                   self.num_timesteps)
            self.writer.add_scalar("losses/alpha", self.alpha,
                                   self.num_timesteps)
            if self.alpha_optim is not None:
                self.writer.add_scalar("losses/alpha_loss", alpha_loss.item(),
                                       self.num_timesteps)

    def learn(self, total_timesteps):
        episode_reward = 0.0
        num_episodes = 0
        obs, done = self.env.reset(), False
        self.num_timesteps = 0
        for step in range(1, total_timesteps + 1):
            self.num_timesteps += 1

            if step < self.learning_starts:
                action = self.env.action_space.sample()
            else:
                with th.no_grad():
                    action = self.policy(
                        th.tensor(obs).float().to(
                            self.device)).detach().cpu().numpy()

            next_obs, reward, done, info = self.env.step(action)

            # bootstrap through time-limit truncations: only true terminals stop the return
            terminal = done if 'TimeLimit.truncated' not in info else not info['TimeLimit.truncated']
            self.replay_buffer.add(obs, action, reward, next_obs, terminal)

            if step >= self.learning_starts:
                if self.mbpo:
                    if self.num_timesteps % self.dynamics_train_freq == 0:
                        m_obs, m_actions, m_rewards, m_next_obs, m_dones = \
                            self.replay_buffer.get_all_data()
                        X = np.hstack((m_obs, m_actions))
                        Y = np.hstack((m_rewards, m_next_obs - m_obs))
                        mean_holdout_loss = self.dynamics.train_ensemble(X, Y)
                        self.writer.add_scalar("dynamics/mean_holdout_loss",
                                               mean_holdout_loss,
                                               self.num_timesteps)

                    if self.num_timesteps >= self.rollout_dynamics_starts and self.num_timesteps % 250 == 0:
                        self.rollout_dynamics()

                self.train()

            episode_reward += reward
            if done:
                obs, done = self.env.reset(), False
                num_episodes += 1

                if num_episodes % 10 == 0:
                    print(
                        f"Episode: {num_episodes} Step: {step}, Ep. Reward: {episode_reward}"
                    )
                if self.log:
                    self.writer.add_scalar("metrics/episode_reward",
                                           episode_reward, self.num_timesteps)

                episode_reward = 0.0
            else:
                obs = next_obs

        if self.log:
            self.writer.close()
        self.env.close()
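
A minimal usage sketch for the SAC class above, under a few assumptions: a Gym-style continuous-control environment with the older 4-tuple step API (which learn() expects), and the helper modules the example imports (Policy, SoftQNetwork, ReplayBuffer) available on the path. The environment id and hyperparameters are illustrative, not taken from the original project.

import gym

env = gym.make('Pendulum-v1')                     # hypothetical choice of env
agent = SAC(env, learning_starts=1000, log=False, wandb=False)
agent.learn(total_timesteps=20000)                # warm-up, then off-policy updates every step
agent.save(save_replay_buffer=False)              # writes weights/sac_<timestamp>.tar
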
Example #3
File: dqn.py Project: sumeetpathania/DDQN
class DQN():
    def __init__(self, state_size, action_size, action_space, args):
        self.device = torch.device("cuda" if args.cuda else "cpu")
        self.buffer = ReplayBuffer(args.buffer_size, args.batch_size, self.device)

        self.action_size = action_size
        self.gamma = args.gamma
        self.tau = args.tau

        self.eps = EpsilonController(e_decays=args.eps_decays, e_min=args.eps_min)

        self.q_local = QNetwork(state_size, action_size, args.hidden_size).to(self.device)
        self.q_optimizer = optim.Adam(self.q_local.parameters(), lr=args.lr)
        self.q_target = copy.deepcopy(self.q_local)


    def act(self, state, eval=False):
        '''Return an action for the given state.
        :param state (np.ndarray): state
        :param eval (bool): whether we are evaluating the policy; set to True in utils.traj.py
        :return action (int): greedy action index, or epsilon-greedy action if not eval
        '''
        state = torch.FloatTensor(state).to(self.device)
        with torch.no_grad():
            action_values = self.q_local(state)

        if not eval:
            # Epsilon-greedy action selection
            if random.random() > self.eps.val():
                action = np.argmax(action_values.cpu().data.numpy())
            else:
                action = random.choice(np.arange(self.action_size))
        else:
            action = np.argmax(action_values.cpu().data.numpy())
        return action

    def val(self, state, action):
        '''return the estimated Q value of state action pair'''
        state = torch.FloatTensor(state).to(self.device)
        # action = torch.LongTensor(action).to(self.device)
        q_value = self.q_local(state)[action]
        return q_value.item()


    def step(self, state, action, reward, next_state, mask):
        '''step on transition'''
        # transition: (state, action, reward, next_state, mask)
        self.buffer.add(state, action, reward, next_state, mask)
        self.eps.update()

    def update(self):
        '''sample batch of experience tuple and update q network'''

        # Sample replay buffer
        batch = self.buffer.sample(discrete = True)
        states, actions, rewards, next_states, not_done = batch

        # Get max predicted Q values (for next states) from target model
        Q_targets_next = self.q_target(next_states).detach().max(1)[0].unsqueeze(1)
        # Compute Q targets for current states
        Q_targets = rewards + (self.gamma * Q_targets_next * not_done)

        # Get expected Q values from local model
        Q_expected = self.q_local(states).gather(1, actions)

        # Compute loss
        loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.q_optimizer.zero_grad()
        loss.backward()
        self.q_optimizer.step()

        # ------------------- update target Q network ------------------- #
        soft_update(self.q_local, self.q_target, self.tau)

        return loss
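
The soft_update call above comes from a project utility that is not shown here. Below is a minimal sketch under that assumption, mirroring the explicit Polyak loop written out in the SAC example (target <- tau * local + (1 - tau) * target); the project's actual helper may differ in argument order or details.

import torch

def soft_update(local_model, target_model, tau):
    """Minimal sketch (assumption): Polyak-average local params into the target network."""
    with torch.no_grad():
        for local_param, target_param in zip(local_model.parameters(),
                                             target_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)
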
Example #4
class Agent():
    def __init__(self, num_agents, state_size, action_size, opts):
        self.num_agents = num_agents
        self.state_size = state_size
        self.action_size = action_size
        self.opts = opts

        # Actor Network
        self.actor_local = ActorNet(state_size,
                                    action_size,
                                    fc1_units=opts.a_fc1,
                                    fc2_units=opts.a_fc2).to(opts.device)
        self.actor_target = ActorNet(state_size,
                                     action_size,
                                     fc1_units=opts.a_fc1,
                                     fc2_units=opts.a_fc2).to(opts.device)
        self.actor_optimizer = torch.optim.Adam(self.actor_local.parameters(),
                                                lr=opts.actor_lr)

        # Critic Network
        self.critic_local = CriticNet(state_size,
                                      action_size,
                                      fc1_units=opts.c_fc1,
                                      fc2_units=opts.c_fc2).to(opts.device)
        self.critic_target = CriticNet(state_size,
                                       action_size,
                                       fc1_units=opts.c_fc1,
                                       fc2_units=opts.c_fc2).to(opts.device)
        self.critic_optimizer = torch.optim.Adam(
            self.critic_local.parameters(),
            lr=opts.critic_lr,
            weight_decay=opts.critic_weight_decay)

        # Noise process
        self.noise = OUNoise((num_agents, action_size), opts.random_seed)
        self.step_idx = 0

        # Replay memory
        self.memory = ReplayBuffer(action_size, opts.buffer_size,
                                   opts.batch_size, opts.random_seed,
                                   opts.device)

    def step(self, state, action, reward, next_state, done):
        for i in range(self.num_agents):
            self.memory.add(state[i, :], action[i, :], reward[i],
                            next_state[i, :], done[i])

        self.step_idx += 1
        is_learn_iteration = (self.step_idx % self.opts.learn_every) == 0
        is_update_iteration = (self.step_idx % self.opts.update_every) == 0

        if len(self.memory) > self.opts.batch_size:
            if is_learn_iteration:
                experiences = self.memory.sample()
                self.learn(experiences, self.opts.gamma)

            if is_update_iteration:
                soft_update(self.critic_local, self.critic_target,
                            self.opts.tau)
                soft_update(self.actor_local, self.actor_target, self.opts.tau)

    def act(self, state):
        state = torch.from_numpy(state).float().to(self.opts.device)

        self.actor_local.eval()
        with torch.no_grad():
            action = self.actor_local(state).cpu().data.numpy()
        self.actor_local.train()

        action += self.noise.sample()
        return np.clip(action, self.opts.minimum_action_value,
                       self.opts.maximum_action_value)

    def save(self):
        torch.save(self.critic_local.state_dict(),
                   self.opts.output_data_path + "critic_local.pth")
        torch.save(self.critic_target.state_dict(),
                   self.opts.output_data_path + "critic_target.pth")
        torch.save(self.actor_local.state_dict(),
                   self.opts.output_data_path + "actor_local.pth")
        torch.save(self.actor_target.state_dict(),
                   self.opts.output_data_path + "actor_target.pth")

    def learn(self, experiences, gamma):
        states, actions, rewards, next_states, dones = experiences

        states = tensor(states, self.opts.device)
        actions = tensor(actions, self.opts.device)
        rewards = tensor(rewards, self.opts.device)
        next_states = tensor(next_states, self.opts.device)
        mask = tensor(1 - dones, self.opts.device)

        # Update critic
        actions_next = self.actor_target(next_states)
        Q_targets_next = self.critic_target(next_states, actions_next)
        Q_targets = rewards + (gamma * Q_targets_next * mask)

        # Compute & minimize critic loss
        Q_expected = self.critic_local(states, actions)
        critic_loss = F.mse_loss(Q_expected, Q_targets)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Update actor
        actions_pred = self.actor_local(states)

        # Compute & minimize actor loss
        actor_loss = -self.critic_local(states, actions_pred).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
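
The OUNoise exploration process used by this agent is also defined outside the snippet. A minimal sketch follows, assuming the standard discretised Ornstein-Uhlenbeck update; the project's class may use different parameter names or defaults.

import numpy as np

class OUNoise:
    """Minimal sketch (assumption) of an Ornstein-Uhlenbeck exploration process."""

    def __init__(self, size, seed, mu=0.0, theta=0.15, sigma=0.2):
        self.mu = mu * np.ones(size)
        self.theta = theta
        self.sigma = sigma
        self.rng = np.random.default_rng(seed)
        self.reset()

    def reset(self):
        # restart the process at its long-run mean
        self.state = self.mu.copy()

    def sample(self):
        # Euler step of the OU process: dx = theta * (mu - x) + sigma * N(0, 1)
        dx = self.theta * (self.mu - self.state) \
            + self.sigma * self.rng.standard_normal(self.mu.shape)
        self.state = self.state + dx
        return self.state
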
Example #5
class TD3():
    def __init__(self,
                 state_size,
                 action_size,
                 action_space,
                 args,
                 policy_noise=0.2,
                 noise_clip=0.5,
                 policy_freq=2):
        self.device = torch.device("cuda" if args.cuda else "cpu")
        self.buffer = ReplayBuffer(args.buffer_size, args.batch_size,
                                   self.device)

        self.action_size = action_size
        self.gamma = args.gamma
        self.tau = args.tau
        self.start_steps = args.start_steps

        self.total_it = 0
        self.max_action = float(action_space.high[0])
        self.policy_noise = policy_noise * self.max_action  # Target policy smoothing is scaled wrt the action scale
        self.noise_clip = noise_clip * self.max_action
        self.policy_freq = policy_freq
        self.expl_noise = args.expl_noise

        self.ce = nn.CrossEntropyLoss()
        self.mse = nn.MSELoss()

        self.policy = Actor(state_size, action_space.shape[0],
                            args.hidden_size, action_space).to(self.device)
        self.policy_optimizer = optim.Adam(self.policy.parameters(),
                                           lr=args.lr)
        self.policy_target = copy.deepcopy(self.policy)

        self.critic_local = QNetwork(state_size, action_space.shape[0],
                                     args.hidden_size).to(self.device)
        self.critic_optimizer = optim.Adam(self.critic_local.parameters(),
                                           lr=args.lr)
        self.critic_target = copy.deepcopy(self.critic_local)

    def act(self, state, eval=False):
        '''Return an action for the given state.
        :param state (np.ndarray): state
        :param eval (bool): whether we are evaluating the policy; set to True in utils.traj.py
        :return action (np.ndarray): action with Gaussian exploration noise if not eval
        '''
        state = torch.FloatTensor(state).to(self.device)
        if not eval:
            noise = np.random.normal(0, self.max_action * self.expl_noise,
                                     size=self.action_size)
            action = (self.policy(state).detach().cpu().numpy() + noise).clip(
                -self.max_action, self.max_action)
        else:
            action = self.policy(state).detach().cpu().numpy()
        return action

    def val(self, state, action):
        '''return the estimated Q value of state action pair'''
        state = torch.FloatTensor(state).to(self.device)
        action = torch.FloatTensor(action).to(self.device)
        v1, v2 = self.critic_local.forward_one(state, action)
        return v1.item()

    def step(self, state, action, reward, next_state, mask):
        '''step on transition'''
        # transition: (state, action, reward, next_state, mask)
        self.buffer.add(state, action, reward, next_state, mask)

    def update(self):
        '''update critic and policy as in TD3'''
        self.total_it += 1

        # Sample replay buffer
        batch = self.buffer.sample()
        state, action, reward, next_state, not_done = batch

        with torch.no_grad():
            # Select action according to policy and add clipped noise
            noise = (torch.randn_like(action) * self.policy_noise).clamp(
                -self.noise_clip, self.noise_clip)

            next_action = (self.policy_target(next_state) + noise).clamp(
                -self.max_action, self.max_action)

            # Compute the target Q value
            target_Q1, target_Q2 = self.critic_target(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + not_done * self.gamma * target_Q

        # Get current Q estimates
        current_Q1, current_Q2 = self.critic_local(state, action)

        # Compute critic loss
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
            current_Q2, target_Q)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Delayed policy updates
        if self.total_it % self.policy_freq == 0:

            # Compute actor loss
            actor_loss = -self.critic_local.Q1(state,
                                               self.policy(state)).mean()

            # Optimize the actor
            self.policy_optimizer.zero_grad()
            actor_loss.backward()
            self.policy_optimizer.step()

            # Update the frozen target models
            soft_update(self.critic_local, self.critic_target, self.tau)
            soft_update(self.policy, self.policy_target, self.tau)
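
A minimal driver sketch for the TD3 class above; this is not from the original project. `args` mimics the argparse namespace the constructor reads, the environment id is illustrative, and the older 4-tuple gym step API is assumed, as in the SAC example.

from types import SimpleNamespace
import gym

# hypothetical hyperparameters matching the attributes the constructor reads
args = SimpleNamespace(cuda=False, buffer_size=100000, batch_size=256,
                       gamma=0.99, tau=0.005, start_steps=1000,
                       hidden_size=256, lr=3e-4, expl_noise=0.1)
env = gym.make('Pendulum-v1')  # hypothetical env choice
agent = TD3(env.observation_space.shape[0], env.action_space.shape[0],
            env.action_space, args)

state, done = env.reset(), False
for t in range(10000):
    # random warm-up actions, then noisy policy actions
    action = env.action_space.sample() if t < args.start_steps else agent.act(state)
    next_state, reward, done, _ = env.step(action)
    agent.step(state, action, reward, next_state, 1.0 - float(done))  # mask = not_done
    if t >= args.start_steps:
        agent.update()
    state, done = (env.reset(), False) if done else (next_state, done)
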