Example #1
class PriorDQN(Trainer):

    def __init__(self, parameters):
        super(PriorDQN, self).__init__(parameters)

        self.replay_buffer = PrioritizedReplayBuffer(
            self.buffersize, parameters["alpha"])

        self.beta_start = parameters["beta_start"]
        self.beta_frames = parameters["beta_frames"]

    def push_to_buffer(self, state, action, reward, next_state, done):
        self.replay_buffer.push(state, action, reward, next_state, done)

    def beta_by_frame(self, frame_idx):
        beta = self.beta_start + frame_idx * \
            (1.0 - self.beta_start) / self.beta_frames
        return min(1.0, beta)

    def compute_td_loss(self, batch_size, frame_idx):

        beta = self.beta_by_frame(frame_idx)

        if len(self.replay_buffer) < batch_size:
            return None

        state, action, reward, next_state, done, indices, weights = self.replay_buffer.sample(
            batch_size, beta)

        state = Variable(torch.FloatTensor(np.float32(state)))
        next_state = Variable(torch.FloatTensor(np.float32(next_state)))
        action = Variable(torch.LongTensor(action))
        reward = Variable(torch.FloatTensor(reward))
        done = Variable(torch.FloatTensor(done))
        weights = Variable(torch.FloatTensor(weights))

        q_values = self.current_model(state)
        q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)

        # Double DQN target: choose the next action with the online network,
        # evaluate it with the target network
        next_q_values = self.current_model(next_state)
        next_q_state_values = self.target_model(next_state)
        next_q_value = next_q_state_values.gather(
            1, torch.max(next_q_values, 1)[1].unsqueeze(1)).squeeze(1)

        expected_q_value = reward + self.gamma * next_q_value * (1 - done)

        # weight the squared TD error by the importance-sampling weights and
        # clip it; the clipped loss (plus a small constant) is the new priority
        loss = (q_value - Variable(expected_q_value.data)).pow(2) * weights
        loss = loss.clamp(max=1.0)
        prios = loss + 1e-5
        loss = loss.mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.replay_buffer.update_priorities(indices, prios.data.cpu().numpy())
        
        return loss
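
A minimal sketch of the list-based proportional PrioritizedReplayBuffer interface that Example #1 relies on (push / sample / update_priorities / __len__). The names and the flat-array storage are illustrative assumptions; the buffer used by the original project typically stores priorities in a sum-tree for O(log N) sampling.

import numpy as np


class PrioritizedReplayBuffer:
    def __init__(self, capacity, alpha=0.6):
        self.capacity = capacity
        self.alpha = alpha  # how strongly priorities skew the sampling distribution
        self.buffer = []
        self.priorities = np.zeros(capacity, dtype=np.float32)
        self.pos = 0

    def push(self, state, action, reward, next_state, done):
        # new transitions enter with the current maximum priority
        max_prio = self.priorities.max() if self.buffer else 1.0
        if len(self.buffer) < self.capacity:
            self.buffer.append((state, action, reward, next_state, done))
        else:
            self.buffer[self.pos] = (state, action, reward, next_state, done)
        self.priorities[self.pos] = max_prio
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size, beta=0.4):
        prios = self.priorities[:len(self.buffer)]
        probs = prios ** self.alpha
        probs /= probs.sum()
        indices = np.random.choice(len(self.buffer), batch_size, p=probs)
        states, actions, rewards, next_states, dones = zip(
            *[self.buffer[i] for i in indices])
        # importance-sampling weights correct the bias of non-uniform sampling
        weights = (len(self.buffer) * probs[indices]) ** (-beta)
        weights /= weights.max()
        return (np.stack(states), actions, rewards, np.stack(next_states),
                dones, indices, weights.astype(np.float32))

    def update_priorities(self, indices, priorities):
        for idx, prio in zip(indices, priorities):
            self.priorities[idx] = prio

    def __len__(self):
        return len(self.buffer)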
Example #2
def learn(env,
          num_actions=3,
          lr=5e-4,
          max_timesteps=100000,
          buffer_size=50000,
          exploration_fraction=0.1,
          exploration_final_eps=0.02,
          train_freq=1,
          batch_size=32,
          print_freq=1,
          checkpoint_freq=10000,
          learning_starts=1000,
          gamma=1.0,
          target_network_update_freq=500,
          prioritized_replay=False,
          prioritized_replay_alpha=0.6,
          prioritized_replay_beta0=0.4,
          prioritized_replay_beta_iters=None,
          prioritized_replay_eps=1e-6,
          num_cpu=16):
    torch.set_num_threads(num_cpu)
    if prioritized_replay:
        replay_buffer = PrioritizedReplayBuffer(
            buffer_size, alpha=prioritized_replay_alpha)
        if prioritized_replay_beta_iters is None:
            prioritized_replay_beta_iters = max_timesteps
        beta_schedule = LinearSchedule(
            prioritized_replay_beta_iters,
            initial_p=prioritized_replay_beta0,
            final_p=1.0)
    else:
        replay_buffer = ReplayBuffer(buffer_size)
        beta_schedule = None
    exploration = LinearSchedule(
        schedule_timesteps=int(exploration_fraction * max_timesteps),
        initial_p=1.0,
        final_p=exploration_final_eps)
    episode_rewards = [0.0]
    saved_mean_reward = None
    obs = env.reset()
    player_relative = obs[0].observation["screen"][_PLAYER_RELATIVE]

    screen = player_relative

    obs, xy_per_marine = common.init(env, obs)

    group_id = 0
    reset = True
    dqn = DQN(num_actions, lr, cuda)  # `cuda` comes from the enclosing module in the original source

    print('\nCollecting experience...')
    checkpoint_path = 'models/deepq/checkpoint.pth.tar'
    if os.path.exists(checkpoint_path):
        dqn, saved_mean_reward = load_checkpoint(dqn, cuda, filename=checkpoint_path)
    for t in range(max_timesteps):
        # Take action and update exploration to the newest value
        # custom process for DefeatZerglingsAndBanelings
        obs, screen, player = common.select_marine(env, obs)
        # action = act(
        #     np.array(screen)[None], update_eps=update_eps, **kwargs)[0]
        action = dqn.choose_action(np.array(screen)[None])
        reset = False
        rew = 0
        new_action = None
        obs, new_action = common.marine_action(env, obs, player, action)
        army_count = env._obs[0].observation.player_common.army_count
        try:
            if army_count > 0 and _ATTACK_SCREEN in obs[0].observation["available_actions"]:
                obs = env.step(actions=new_action)
            else:
                new_action = [sc2_actions.FunctionCall(_NO_OP, [])]
                obs = env.step(actions=new_action)
        except Exception as e:
            # print(e)
            pass  # ignore failed environment steps
        player_relative = obs[0].observation["screen"][_PLAYER_RELATIVE]
        new_screen = player_relative
        rew += obs[0].reward
        done = obs[0].step_type == environment.StepType.LAST
        selected = obs[0].observation["screen"][_SELECTED]
        player_y, player_x = (selected == _PLAYER_FRIENDLY).nonzero()
        if len(player_y) > 0:
            player = [int(player_x.mean()), int(player_y.mean())]
        if len(player) == 2:
            if player[0] > 32:
                new_screen = common.shift(LEFT, player[0] - 32, new_screen)
            elif player[0] < 32:
                new_screen = common.shift(RIGHT, 32 - player[0],
                                          new_screen)
            if player[1] > 32:
                new_screen = common.shift(UP, player[1] - 32, new_screen)
            elif player[1] < 32:
                new_screen = common.shift(DOWN, 32 - player[1], new_screen)
        # Store transition in the replay buffer.
        replay_buffer.add(screen, action, rew, new_screen, float(done))
        screen = new_screen
        episode_rewards[-1] += rew
        reward = episode_rewards[-1]
        if done:
            print("Episode Reward : %s" % episode_rewards[-1])
            obs = env.reset()
            player_relative = obs[0].observation["screen"][
                _PLAYER_RELATIVE]
            screen = player_relative
            group_list = common.init(env, obs)
            # Select all marines first
            # env.step(actions=[sc2_actions.FunctionCall(_SELECT_UNIT, [_SELECT_ALL])])
            episode_rewards.append(0.0)
            reset = True

        if t > learning_starts and t % train_freq == 0:
            # Minimize the error in Bellman's equation on a batch sampled from the replay buffer.
            if prioritized_replay:
                experience = replay_buffer.sample(
                    batch_size, beta=beta_schedule.value(t))
                (obses_t, actions, rewards, obses_tp1, dones, weights,
                 batch_idxes) = experience
            else:
                obses_t, actions, rewards, obses_tp1, dones = replay_buffer.sample(
                    batch_size)
                weights, batch_idxes = np.ones_like(rewards), None

            td_errors = dqn.learn(obses_t, actions, rewards, obses_tp1, gamma, batch_size)

            if prioritized_replay:
                new_priorities = np.abs(td_errors) + prioritized_replay_eps
                replay_buffer.update_priorities(batch_idxes,
                                                new_priorities)

        if t > learning_starts and t % target_network_update_freq == 0:
            # Update target network periodically.
            dqn.update_target()

        mean_100ep_reward = round(np.mean(episode_rewards[-101:-1]), 1)
        num_episodes = len(episode_rewards)
        if done and print_freq is not None and len(
                episode_rewards) % print_freq == 0:
            logger.record_tabular("steps", t)
            logger.record_tabular("episodes", num_episodes)
            logger.record_tabular("reward", reward)
            logger.record_tabular("mean 100 episode reward",
                                  mean_100ep_reward)
            logger.record_tabular("% time spent exploring",
                                  int(100 * exploration.value(t)))
            logger.dump_tabular()

        if (checkpoint_freq is not None and t > learning_starts
                and num_episodes > 100 and t % checkpoint_freq == 0):
            if saved_mean_reward is None or mean_100ep_reward > saved_mean_reward:
                if print_freq is not None:
                    logger.log(
                        "Saving model due to mean reward increase: {} -> {}".format(
                            saved_mean_reward,
                            mean_100ep_reward))
                save_checkpoint({
                    'epoch': t + 1,
                    'state_dict': dqn.save_state_dict(),
                    'best_accuracy': mean_100ep_reward
                }, checkpoint_path)
                saved_mean_reward = mean_100ep_reward
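
Examples #2 and #3 also assume a LinearSchedule helper in the style of openai/baselines: value(t) interpolates linearly from initial_p to final_p over schedule_timesteps and then stays at final_p. A minimal sketch:

class LinearSchedule:
    def __init__(self, schedule_timesteps, initial_p=1.0, final_p=0.02):
        self.schedule_timesteps = schedule_timesteps
        self.initial_p = initial_p
        self.final_p = final_p

    def value(self, t):
        # fraction of the schedule completed, capped at 1.0
        fraction = min(float(t) / self.schedule_timesteps, 1.0)
        return self.initial_p + fraction * (self.final_p - self.initial_p)


# e.g. annealing the prioritized-replay beta from 0.4 to 1.0 over the run:
# beta_schedule = LinearSchedule(max_timesteps, initial_p=0.4, final_p=1.0)
# beta = beta_schedule.value(t)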
Example #3
class DQN:
    def __init__(self, config):
        self.writer = SummaryWriter() 
        self.device = 'cuda' if T.cuda.is_available() else 'cpu'

        self.dqn_type = config["dqn-type"]
        self.run_title = config["run-title"]
        self.env = gym.make(config["environment"])

        self.num_states  = np.prod(self.env.observation_space.shape)
        self.num_actions = self.env.action_space.n

        layers = [
            self.num_states, 
            *config["architecture"], 
            self.num_actions
        ]

        self.policy_net = Q_Network(self.dqn_type, layers).to(self.device)
        self.target_net = Q_Network(self.dqn_type, layers).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        capacity = config["max-experiences"]
        self.p_replay_eps = config["p-eps"]
        self.prioritized_replay = config["prioritized-replay"]
        self.replay_buffer = PrioritizedReplayBuffer(capacity, config["p-alpha"]) if self.prioritized_replay \
                        else ReplayBuffer(capacity)

        self.beta_scheduler = LinearSchedule(config["episodes"], initial_p=config["p-beta-init"], final_p=1.0)
        self.epsilon_decay = lambda e: max(config["epsilon-min"], e * config["epsilon-decay"])

        self.train_freq = config["train-freq"]
        self.use_soft_update = config["use-soft-update"]
        self.target_update = config["target-update"]
        self.tau = config["tau"]
        self.gamma = config["gamma"]
        self.batch_size = config["batch-size"]
        self.time_step = 0

        self.optim = T.optim.AdamW(self.policy_net.parameters(), lr=config["lr-init"], weight_decay=config["weight-decay"])
        self.lr_scheduler = T.optim.lr_scheduler.StepLR(self.optim, step_size=config["lr-step"], gamma=config["lr-gamma"])
        self.criterion = nn.SmoothL1Loss(reduction="none") # Huber Loss
        self.min_experiences = max(config["min-experiences"], config["batch-size"])

        self.save_path = config["save-path"]

    def act(self, state, epsilon=0):
        """
            Act on environment using epsilon-greedy policy
        """
        if np.random.sample() < epsilon:
            return int(np.random.choice(np.arange(self.num_actions)))
        else:
            self.policy_net.eval()
            return self.policy_net(T.tensor(state, device=self.device).float().unsqueeze(0)).argmax().item()

    def _soft_update(self, tau):
        """
            Polyak averaging: soft update model parameters. 
            θ_target = τ*θ_current + (1 - τ)*θ_target
        """
        for target_param, current_param in zip(self.target_net.parameters(), self.policy_net.parameters()):
            target_param.data.copy_(tau*target_param.data + (1.0-tau)*current_param.data)

    def update_target(self, tau):
        if self.use_soft_update:
            self._soft_update(tau)
        elif self.time_step % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

    def optimize(self, beta=None):
        if len(self.replay_buffer) < self.min_experiences:
            return None, None 

        self.policy_net.train()

        if self.prioritized_replay:
            transitions, (is_weights, t_idxes) = self.replay_buffer.sample(self.batch_size, beta)
        else:
            transitions = self.replay_buffer.sample(self.batch_size)
            is_weights, t_idxes = np.ones(self.batch_size), None

        # transpose the batch: a list of Transitions into a Transition of batch arrays
        batch = Transition(*zip(*transitions))
        # compute a mask of non-final states and concatenate the batch elements
        non_final_mask = T.tensor(tuple(map(lambda state: state is not None, batch.next_state)), 
                                                                device=self.device, dtype=T.bool)  
        non_final_next_states = T.cat([T.tensor([state]).float() for state in batch.next_state if state is not None]).to(self.device)

        state_batch  = T.tensor(batch.state,  device=self.device).float()
        action_batch = T.tensor(batch.action, device=self.device).long()
        reward_batch = T.tensor(batch.reward, device=self.device).float()

        state_action_values = self.policy_net(state_batch).gather(1, action_batch.unsqueeze(1))
    
        next_state_values = T.zeros(self.batch_size, device=self.device)
        if self.dqn_type == "vanilla":
            next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1)[0].detach()
        else:
            self.policy_net.eval()
            action_next_state = self.policy_net(non_final_next_states).max(1)[1]
            self.policy_net.train()
            next_state_values[non_final_mask] = self.target_net(non_final_next_states).gather(1, action_next_state.unsqueeze(1)).squeeze().detach()

        # compute the expected Q values (RHS of the Bellman equation)
        expected_state_action_values = (next_state_values * self.gamma) + reward_batch
        
        # compute temporal difference error
        td_error = T.abs(state_action_values.squeeze() - expected_state_action_values).detach().cpu().numpy()

        # compute Huber loss
        loss = self.criterion(state_action_values, expected_state_action_values.unsqueeze(1))
        loss = T.mean(loss * T.tensor(is_weights, device=self.device))
      
        # optimize the model
        self.optim.zero_grad()
        loss.backward()
        for param in self.policy_net.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optim.step()

        return td_error, t_idxes

    def run_episode(self, epsilon, beta):
        total_reward, done = 0, False
        state = self.env.reset()
        while not done:
            # use epsilon-greedy to get an action
            action = self.act(state, epsilon)
            # caching the information of current state
            prev_state = state
            # take action
            state, reward, done, _ = self.env.step(action)
            # accumulate reward
            total_reward += reward
            # store the transition in buffer
            if done: state = None 
            self.replay_buffer.push(prev_state, action, state, reward)
            # optimize model
            if self.time_step % self.train_freq == 0:
                td_error, t_idxes = self.optimize(beta=beta)
                # update priorities 
                if self.prioritized_replay and td_error is not None:
                    self.replay_buffer.update_priorities(t_idxes, td_error + self.p_replay_eps)
            # update target network
            self.update_target(self.tau)
            # increment time-step
            self.time_step += 1

        return total_reward

    def train(self, episodes, epsilon, solved_reward):
        total_rewards = np.zeros(episodes)
        for episode in range(episodes):
            
            # compute beta using linear scheduler
            beta = self.beta_scheduler.value(episode)
            # run episode and get rewards
            reward = self.run_episode(epsilon, beta)
            # exponentially decay epsilon
            epsilon = self.epsilon_decay(epsilon)
            # step the learning-rate scheduler
            self.lr_scheduler.step()

            total_rewards[episode] = reward
            avg_reward = total_rewards[max(0, episode-100):(episode+1)].mean()
            last_lr = self.lr_scheduler.get_last_lr()[0]

            # log into tensorboard
            self.writer.add_scalar(f'dqn-{self.dqn_type}/reward', reward, episode)
            self.writer.add_scalar(f'dqn-{self.dqn_type}/reward_100', avg_reward, episode)
            self.writer.add_scalar(f'dqn-{self.dqn_type}/lr', last_lr, episode)
            self.writer.add_scalar(f'dqn-{self.dqn_type}/epsilon', epsilon, episode)

            print(f"Episode: {episode} | Last 100 Average Reward: {avg_reward:.5f} | Learning Rate: {last_lr:.5E} | Epsilon: {epsilon:.5E}", end='\r')

            if avg_reward > solved_reward:
                break
        
        self.writer.close()

        print(f"Environment solved in {episode} episodes")
        T.save(self.policy_net.state_dict(), os.path.join(self.save_path, f"{self.run_title}.pt"))

    def visualize(self, load_path=None):
        done = False
        state = self.env.reset()

        if load_path is not None:
            self.policy_net.load_state_dict(T.load(load_path, map_location=self.device))
        self.policy_net.eval()
        
        while not done:
            self.env.render()
            action = self.act(state)
            state, _, done, _ = self.env.step(int(action))
            sleep(0.01) 
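
Example #3 rebuilds sampled transitions with Transition(*zip(*transitions)), so it assumes a namedtuple-based buffer in the spirit of the PyTorch DQN tutorial. A minimal sketch of that interface (the prioritized variant used above additionally returns importance-sampling weights and tree indices from sample()):

import random
from collections import deque, namedtuple

Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))


class ReplayBuffer:
    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        # stores one Transition(state, action, next_state, reward)
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        # uniform sampling without importance weights
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)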
Example #4
def learn(logger, device, env, number_timesteps, network, optimizer, save_path,
          save_interval, ob_scale, gamma, grad_norm, double_q, param_noise,
          exploration_fraction, exploration_final_eps, batch_size, train_freq,
          learning_starts, target_network_update_freq, buffer_size,
          prioritized_replay, prioritized_replay_alpha,
          prioritized_replay_beta0, atom_num, min_value, max_value):
    """
    Papers:
    Mnih V, Kavukcuoglu K, Silver D, et al. Human-level control through deep
    reinforcement learning[J]. Nature, 2015, 518(7540): 529.
    Hessel M, Modayil J, Van Hasselt H, et al. Rainbow: Combining Improvements
    in Deep Reinforcement Learning[J]. 2017.

    Parameters
    ----------
    double_q (bool): if True, double DQN will be used
    param_noise (bool): whether or not to use parameter space noise
    exploration_fraction (float): fraction of the entire training period over
                                  which the exploration rate is annealed
    exploration_final_eps (float): final value of the random action probability
    batch_size (int): size of a batch sampled from the replay buffer for training
    train_freq (int): update the model every `train_freq` steps
    learning_starts (int): how many steps of transitions to collect before
                           learning starts
    target_network_update_freq (int): update the target network every
                                      `target_network_update_freq` steps
    buffer_size (int): size of the replay buffer
    prioritized_replay (bool): if True, a prioritized replay buffer will be used
    prioritized_replay_alpha (float): alpha parameter for prioritized replay
    prioritized_replay_beta0 (float): initial beta parameter for prioritized replay
    atom_num (int): number of atoms in distributional RL; a distributional
                    head is used when atom_num > 1
    min_value (float): minimum value of the support in distributional RL
    max_value (float): maximum value of the support in distributional RL

    """

    qnet = network.to(device)
    qtar = deepcopy(qnet)
    if prioritized_replay:
        buffer = PrioritizedReplayBuffer(buffer_size, device,
                                         prioritized_replay_alpha,
                                         prioritized_replay_beta0)
    else:
        buffer = ReplayBuffer(buffer_size, device)
    generator = _generate(device, env, qnet, ob_scale, number_timesteps,
                          param_noise, exploration_fraction,
                          exploration_final_eps, atom_num, min_value,
                          max_value)
    if atom_num > 1:
        delta_z = float(max_value - min_value) / (atom_num - 1)
        z_i = torch.linspace(min_value, max_value, atom_num).to(device)

    infos = {'eplenmean': deque(maxlen=100), 'eprewmean': deque(maxlen=100)}
    start_ts = time.time()
    for n_iter in range(1, number_timesteps + 1):
        if prioritized_replay:
            buffer.beta += (1 - prioritized_replay_beta0) / number_timesteps
        *data, info = generator.__next__()
        buffer.add(*data)
        for k, v in info.items():
            infos[k].append(v)

        # update qnet
        if n_iter > learning_starts and n_iter % train_freq == 0:
            b_o, b_a, b_r, b_o_, b_d, *extra = buffer.sample(batch_size)
            b_o.mul_(ob_scale)
            b_o_.mul_(ob_scale)

            if atom_num == 1:
                with torch.no_grad():
                    if double_q:
                        b_a_ = qnet(b_o_).argmax(1).unsqueeze(1)
                        b_q_ = (1 - b_d) * qtar(b_o_).gather(1, b_a_)
                    else:
                        b_q_ = (1 - b_d) * qtar(b_o_).max(1, keepdim=True)[0]
                b_q = qnet(b_o).gather(1, b_a)
                abs_td_error = (b_q - (b_r + gamma * b_q_)).abs()
                priorities = abs_td_error.detach().cpu().clamp(1e-6).numpy()
                if extra:
                    loss = (extra[0] * huber_loss(abs_td_error)).mean()
                else:
                    loss = huber_loss(abs_td_error).mean()
            else:
                with torch.no_grad():
                    # the network outputs log-probabilities, so exp() recovers
                    # the target distribution over atoms
                    b_dist_ = qtar(b_o_).exp()
                    b_a_ = (b_dist_ * z_i).sum(-1).argmax(1)
                    b_tzj = (gamma * (1 - b_d) * z_i[None, :] + b_r).clamp(
                        min_value, max_value)
                    b_i = (b_tzj - min_value) / delta_z
                    b_l = b_i.floor()
                    b_u = b_i.ceil()
                    b_m = torch.zeros(batch_size, atom_num).to(device)
                    temp = b_dist_[torch.arange(batch_size), b_a_, :]
                    b_m.scatter_add_(1, b_l.long(), temp * (b_u - b_i))
                    b_m.scatter_add_(1, b_u.long(), temp * (b_i - b_l))
                b_q = qnet(b_o)[torch.arange(batch_size), b_a.squeeze(1), :]
                kl_error = -(b_q * b_m).sum(1)
                # use kl error as priorities as proposed by Rainbow
                priorities = kl_error.detach().cpu().clamp(1e-6).numpy()
                loss = kl_error.mean()

            optimizer.zero_grad()
            loss.backward()
            if grad_norm is not None:
                nn.utils.clip_grad_norm_(qnet.parameters(), grad_norm)
            optimizer.step()
            if prioritized_replay:
                buffer.update_priorities(extra[1], priorities)

        # update target net and log
        if n_iter % target_network_update_freq == 0:
            qtar.load_state_dict(qnet.state_dict())
            logger.info('{} Iter {} {}'.format('=' * 10, n_iter, '=' * 10))
            fps = int(n_iter / (time.time() - start_ts))
            logger.info('Total timesteps {} FPS {}'.format(n_iter, fps))
            for k, v in infos.items():
                v = (sum(v) / len(v)) if v else float('nan')
                logger.info('{}: {:.6f}'.format(k, v))
            if n_iter > learning_starts and n_iter % train_freq == 0:
                logger.info('vloss: {:.6f}'.format(loss.item()))

        if save_interval and n_iter % save_interval == 0:
            torch.save(
                [qnet.state_dict(), optimizer.state_dict()],
                os.path.join(save_path, '{}.checkpoint'.format(n_iter)))
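
The atom_num > 1 branch in Example #4 is the categorical (C51) projection of the target distribution onto the fixed support z_i. A standalone sketch of that projection step, assuming next_dist already holds the target network's probabilities for the greedy next action (function name and argument layout are illustrative):

import torch


def project_distribution(next_dist, rewards, dones, z_i, gamma, v_min, v_max):
    # next_dist: (batch, atoms) probabilities; rewards, dones: (batch, 1)
    batch_size, atom_num = next_dist.shape
    delta_z = (v_max - v_min) / (atom_num - 1)
    # apply the Bellman update to every atom and clamp onto the support
    tz = (rewards + gamma * (1 - dones) * z_i[None, :]).clamp(v_min, v_max)
    b = (tz - v_min) / delta_z
    lower, upper = b.floor(), b.ceil()
    projected = torch.zeros(batch_size, atom_num, device=next_dist.device)
    # split each atom's probability mass between its two neighbouring bins
    projected.scatter_add_(1, lower.long(), next_dist * (upper - b))
    projected.scatter_add_(1, upper.long(), next_dist * (b - lower))
    return projected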
Example #5
def learn(device,
          env, seed,
          number_timesteps,
          network, optimizer,
          save_path, save_interval, ob_scale,
          gamma, grad_norm,
          double_q, param_noise,
          exploration_fraction, exploration_final_eps,
          batch_size, train_freq, learning_starts, target_network_update_freq,
          buffer_size, prioritized_replay, prioritized_replay_alpha,
          prioritized_replay_beta0):
    """
    Papers:
    Mnih V, Kavukcuoglu K, Silver D, et al. Human-level control through deep
    reinforcement learning[J]. Nature, 2015, 518(7540): 529.
    Hessel M, Modayil J, Van Hasselt H, et al. Rainbow: Combining Improvements
    in Deep Reinforcement Learning[J]. 2017.

    Parameters
    ----------
    double_q (bool): if True, double DQN will be used
    param_noise (bool): whether or not to use parameter space noise
    exploration_fraction (float): fraction of the entire training period over
                                  which the exploration rate is annealed
    exploration_final_eps (float): final value of the random action probability
    batch_size (int): size of a batch sampled from the replay buffer for training
    train_freq (int): update the model every `train_freq` steps
    learning_starts (int): how many steps of transitions to collect before
                           learning starts
    target_network_update_freq (int): update the target network every
                                      `target_network_update_freq` steps
    buffer_size (int): size of the replay buffer
    prioritized_replay (bool): if True, a prioritized replay buffer will be used
    prioritized_replay_alpha (float): alpha parameter for prioritized replay
    prioritized_replay_beta0 (float): initial beta parameter for prioritized replay

    """
    name = '{}_{}'.format(os.path.split(__file__)[-1][:-3], seed)
    logger = get_logger(name)
    logger.info('Note that the Rainbow features supported in the current '
                'version are consistent with openai/baselines, which means '
                '`Multi-step` and `Distributional` are missing. '
                'Contributions are welcome!')

    qnet = network.to(device)
    qtar = deepcopy(qnet)
    if prioritized_replay:
        buffer = PrioritizedReplayBuffer(buffer_size, device,
                                         prioritized_replay_alpha,
                                         prioritized_replay_beta0)
    else:
        buffer = ReplayBuffer(buffer_size, device)
    generator = _generate(device, env, qnet, ob_scale,
                          number_timesteps, param_noise,
                          exploration_fraction, exploration_final_eps)

    infos = {'eplenmean': deque(maxlen=100), 'eprewmean': deque(maxlen=100)}
    start_ts = time.time()
    for n_iter in range(1, number_timesteps + 1):
        if prioritized_replay:
            buffer.beta += (1 - prioritized_replay_beta0) / number_timesteps
        *data, info = generator.__next__()
        buffer.add(*data)
        for k, v in info.items():
            infos[k].append(v)

        # update qnet
        if n_iter > learning_starts and n_iter % train_freq == 0:
            b_o, b_a, b_r, b_o_, b_d, *extra = buffer.sample(batch_size)
            b_o.mul_(ob_scale)
            b_o_.mul_(ob_scale)
            b_q = qnet(b_o).gather(1, b_a)
            with torch.no_grad():
                if double_q:
                    b_a_ = qnet(b_o_).argmax(1).unsqueeze(1)
                    b_q_ = (1 - b_d) * qtar(b_o_).gather(1, b_a_)
                else:
                    b_q_ = (1 - b_d) * qtar(b_o_).max(1, keepdim=True)[0]
            abs_td_error = (b_q - (b_r + gamma * b_q_)).abs()
            if extra:
                loss = (extra[0] * huber_loss(abs_td_error)).mean()  # weighted
            else:
                loss = huber_loss(abs_td_error).mean()
            optimizer.zero_grad()
            loss.backward()
            if grad_norm is not None:
                nn.utils.clip_grad_norm_(qnet.parameters(), grad_norm)
            optimizer.step()
            if prioritized_replay:
                priorities = abs_td_error.detach().cpu().clamp(1e-6).numpy()
                buffer.update_priorities(extra[1], priorities)

        # update target net and log
        if n_iter % target_network_update_freq == 0:
            qtar.load_state_dict(qnet.state_dict())
            logger.info('{} Iter {} {}'.format('=' * 10, n_iter, '=' * 10))
            fps = int(n_iter / (time.time() - start_ts))
            logger.info('Total timesteps {} FPS {}'.format(n_iter, fps))
            for k, v in infos.items():
                v = (sum(v) / len(v)) if v else float('nan')
                logger.info('{}: {:.6f}'.format(k, v))
            if n_iter > learning_starts and n_iter % train_freq == 0:
                logger.info('vloss: {:.6f}'.format(loss.item()))

        if save_interval and n_iter % save_interval == 0:
            torch.save([qnet.state_dict(), optimizer.state_dict()],
                       os.path.join(save_path, '{}.{}'.format(name, n_iter)))
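
Both Example #4 and Example #5 apply a huber_loss helper to the absolute TD error, in the style of openai/baselines. A minimal sketch (with the usual delta = 1):

import torch


def huber_loss(abs_td_error, delta=1.0):
    # quadratic for |error| <= delta, linear beyond it, element-wise
    quadratic = torch.clamp(abs_td_error, max=delta)
    linear = abs_td_error - quadratic
    return 0.5 * quadratic.pow(2) + delta * linear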