Exemplo n.º 1
0
class DQN:
    """DQN agent with an epsilon-greedy policy network and a separate target network.

    Args:
        state_dim: size of the state vector fed to the networks.
        action_dim: number of discrete actions.
        cfg: config object; this class reads ``gamma``, ``epsilon_start``,
            ``epsilon_end``, ``epsilon_decay``, ``batch_size``, ``device``,
            ``hidden_dim``, ``lr`` and ``memory_capacity`` from it.
    """

    def __init__(self, state_dim, action_dim, cfg):
        self.action_dim = action_dim
        self.state_dim = state_dim
        self.loss = 0
        self.gamma = cfg.gamma
        self.frame_idx = 0  # step counter driving the epsilon decay
        # Exploration rate decays exponentially from epsilon_start towards epsilon_end.
        self.epsilon = lambda frame_idx: cfg.epsilon_end + (cfg.epsilon_start - cfg.epsilon_end) * math.exp(-1. * frame_idx / cfg.epsilon_decay)
        self.batch_size = cfg.batch_size
        self.device = cfg.device
        self.policy_net = MLP(state_dim, action_dim, cfg.hidden_dim).to(cfg.device)
        self.target_net = MLP(state_dim, action_dim, cfg.hidden_dim).to(cfg.device)
        # Start the target network as an exact copy of the policy network.
        # Otherwise the first TD targets come from an unrelated random network.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg.lr)
        self.memory = ReplayBuffer(cfg.memory_capacity)

    def choose_action(self, state):
        '''Pick an epsilon-greedy action for ``state`` using the policy network.

        The chosen transitions go to the replay buffer, and Q-values are
        recomputed from sampled batches during ``update``. Gradients are
        therefore not tracked here.
        '''
        self.frame_idx += 1
        if random.random() > self.epsilon(self.frame_idx):
            with torch.no_grad():  # no autograd graph is needed for acting
                state = torch.tensor([state], device=self.device, dtype=torch.float32)
                q_value = self.policy_net(state)
                action = q_value.max(1)[1].item()  # max(1)[1] is the argmax index, i.e. the action
        else:
            action = random.randrange(self.action_dim)
        return action

    def update(self):
        """Run one DQN gradient step on a sampled minibatch.

        Does nothing until the replay buffer holds ``batch_size`` transitions.
        The TD target is ``r + gamma * max_a Q_target(s', a) * (1 - done)``.
        """
        if len(self.memory) < self.batch_size:
            return
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.memory.sample(self.batch_size)
        # Convert the sampled batch to tensors.
        state_batch = torch.tensor(state_batch, device=self.device, dtype=torch.float)
        action_batch = torch.tensor(action_batch, device=self.device).unsqueeze(1)
        reward_batch = torch.tensor(reward_batch, device=self.device, dtype=torch.float)
        next_state_batch = torch.tensor(next_state_batch, device=self.device, dtype=torch.float)
        done_batch = torch.tensor(np.float32(done_batch), device=self.device)

        # Q(s_t, a_t) for the actions actually taken; these receive gradients.
        q_values = self.policy_net(state_batch).gather(dim=1, index=action_batch)

        # Maximum target-network Q-value in s_{t+1}; no gradient flows through it.
        next_q_values = self.target_net(next_state_batch).max(1)[0].detach()
        # For terminal transitions (done == 1) the target is just the reward.
        expected_q_values = reward_batch + self.gamma * next_q_values * (1 - done_batch)
        self.loss = nn.MSELoss()(q_values, expected_q_values.unsqueeze(1))
        self.optimizer.zero_grad()
        self.loss.backward()
        self.optimizer.step()

    def save(self, path):
        """Save the target network's weights to ``path + 'DQN_CheckPoint.pth'``."""
        torch.save(self.target_net.state_dict(), path+'DQN_CheckPoint.pth')

    def load(self, path):
        """Load weights saved by ``save`` into both the target and the policy network.

        The policy network is synced as well, so that ``choose_action`` acts
        with the restored weights.
        """
        state_dict = torch.load(path+'DQN_CheckPoint.pth', map_location=self.device)
        self.target_net.load_state_dict(state_dict)
        self.policy_net.load_state_dict(self.target_net.state_dict())
Exemplo n.º 2
0
class DoubleDQN:
    """Double DQN agent: the policy net selects next actions and the target net evaluates them.

    Args:
        state_dim: size of the state vector fed to the networks.
        action_dim: total number of discrete actions.
        cfg: config object; reads ``device``, ``gamma``, ``epsilon_start``,
            ``epsilon_end``, ``epsilon_decay``, ``batch_size``, ``hidden_dim``,
            ``lr`` and ``memory_capacity``.
    """

    def __init__(self, state_dim, action_dim, cfg):

        self.action_dim = action_dim  # total number of actions
        self.device = cfg.device  # device, e.g. cpu or gpu
        self.gamma = cfg.gamma
        # epsilon-greedy parameters
        self.actions_count = 0
        self.epsilon_start = cfg.epsilon_start
        self.epsilon_end = cfg.epsilon_end
        self.epsilon_decay = cfg.epsilon_decay
        self.batch_size = cfg.batch_size
        self.policy_net = MLP(state_dim, action_dim,
                              hidden_dim=cfg.hidden_dim).to(self.device)
        self.target_net = MLP(state_dim, action_dim,
                              hidden_dim=cfg.hidden_dim).to(self.device)
        # The target network starts as an exact copy of the policy network.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()  # disables BatchNorm / Dropout training behaviour
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg.lr)
        self.loss = 0
        self.memory = ReplayBuffer(cfg.memory_capacity)

    def predict(self, state):
        """Return the greedy action (argmax of the policy network's Q-values) for ``state``."""
        with torch.no_grad():
            # Wrapping in a list adds a batch dimension of 1;
            # equivalent to torch.tensor(state).unsqueeze(0).
            state = torch.tensor([state],
                                 device=self.device,
                                 dtype=torch.float32)
            q_value = self.policy_net(state)
            # tensor.max(1) returns (values, indices) per row; [1] is the argmax, i.e. the action.
            action = q_value.max(1)[1].item()
        return action

    def choose_action(self, state):
        '''Pick an epsilon-greedy action and advance the epsilon decay counter.'''
        self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
            math.exp(-1. * self.actions_count / self.epsilon_decay)
        self.actions_count += 1
        if random.random() > self.epsilon:
            action = self.predict(state)
        else:
            action = random.randrange(self.action_dim)
        return action

    def update(self):
        """Run one Double DQN gradient step on a sampled minibatch.

        Does nothing until the replay buffer holds ``batch_size`` transitions.
        The TD target is
        ``r + gamma * Q_target(s', argmax_a Q_policy(s', a)) * (1 - done)``.
        """
        if len(self.memory) < self.batch_size:
            return
        # Sample transitions at random from the replay buffer.
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.memory.sample(
            self.batch_size)
        state_batch = torch.tensor(state_batch,
                                   device=self.device,
                                   dtype=torch.float)
        # (B, 1) action indices for gather, e.g. tensor([[1], ..., [0]])
        action_batch = torch.tensor(action_batch,
                                    device=self.device).unsqueeze(1)
        reward_batch = torch.tensor(reward_batch,
                                    device=self.device,
                                    dtype=torch.float)
        next_state_batch = torch.tensor(next_state_batch,
                                        device=self.device,
                                        dtype=torch.float)
        # Done flags as a (B,) float mask aligned with reward_batch. The
        # previous code unsqueezed this and indexed done_batch[0], which
        # broadcast the FIRST sample's flag across the whole batch.
        done_batch = torch.tensor(np.float32(done_batch), device=self.device)

        # Q(s_t, a_t) for the actions actually taken; gradients flow through this.
        q_value = self.policy_net(state_batch).gather(dim=1, index=action_batch)

        # The target is a constant w.r.t. the optimised parameters. Computing
        # it without autograd avoids building graphs and accumulating stray
        # gradients on target_net.
        with torch.no_grad():
            # Double DQN: the policy net chooses a' = argmax_a Q(s_{t+1}, a) ...
            next_actions = self.policy_net(next_state_batch).max(1)[1].unsqueeze(1)
            # ... and the target net evaluates Q'(s_{t+1}, a').
            next_target_q_value = self.target_net(next_state_batch).gather(
                1, next_actions).squeeze(1)
            # Terminal transitions (done == 1) bootstrap nothing: the target is the reward.
            q_target = reward_batch + self.gamma * next_target_q_value * (1 - done_batch)

        self.loss = nn.MSELoss()(q_value, q_target.unsqueeze(1))  # mean-squared TD error
        self.optimizer.zero_grad()  # drop gradients from the previous step
        self.loss.backward()
        # Clamp every gradient element to [-1, 1] to guard against exploding gradients.
        nn.utils.clip_grad_value_(self.policy_net.parameters(), 1)
        self.optimizer.step()

    def save(self, path):
        """Save the target network's weights to ``path + 'checkpoint.pth'``."""
        torch.save(self.target_net.state_dict(), path + 'checkpoint.pth')

    def load(self, path):
        """Load weights saved by ``save`` into the target network and copy them into the policy network."""
        self.target_net.load_state_dict(
            torch.load(path + 'checkpoint.pth', map_location=self.device))
        for target_param, param in zip(self.target_net.parameters(),
                                       self.policy_net.parameters()):
            param.data.copy_(target_param.data)
Exemplo n.º 3
0
class HierarchicalDQN:
    def __init__(self, state_dim, action_dim, cfg):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = cfg.gamma
        self.device = cfg.device
        self.batch_size = cfg.batch_size
        self.frame_idx = 0
        self.epsilon = lambda frame_idx: cfg.epsilon_end + (
            cfg.epsilon_start - cfg.epsilon_end) * math.exp(-1. * frame_idx /
                                                            cfg.epsilon_decay)
        self.policy_net = MLP(2 * state_dim, action_dim,
                              cfg.hidden_dim).to(self.device)
        self.meta_policy_net = MLP(state_dim, state_dim,
                                   cfg.hidden_dim).to(self.device)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg.lr)
        self.meta_optimizer = optim.Adam(self.meta_policy_net.parameters(),
                                         lr=cfg.lr)
        self.memory = ReplayBuffer(cfg.memory_capacity)
        self.meta_memory = ReplayBuffer(cfg.memory_capacity)
        self.loss_numpy = 0
        self.meta_loss_numpy = 0
        self.losses = []
        self.meta_losses = []

    def to_onehot(self, x):
        oh = np.zeros(self.state_dim)
        oh[x - 1] = 1.
        return oh

    def set_goal(self, state):
        if random.random() > self.epsilon(self.frame_idx):
            with torch.no_grad():
                state = torch.tensor(state,
                                     device=self.device,
                                     dtype=torch.float32).unsqueeze(0)
                goal = self.meta_policy_net(state).max(1)[1].item()
        else:
            goal = random.randrange(self.state_dim)
        return goal

    def choose_action(self, state):
        self.frame_idx += 1
        if random.random() > self.epsilon(self.frame_idx):
            with torch.no_grad():
                state = torch.tensor(state,
                                     device=self.device,
                                     dtype=torch.float32).unsqueeze(0)
                q_value = self.policy_net(state)
                action = q_value.max(1)[1].item()
        else:
            action = random.randrange(self.action_dim)
        return action

    def update(self):
        self.update_policy()
        self.update_meta()

    def update_policy(self):
        if self.batch_size > len(self.memory):
            return
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.memory.sample(
            self.batch_size)
        state_batch = torch.tensor(state_batch, dtype=torch.float)
        action_batch = torch.tensor(action_batch,
                                    dtype=torch.int64).unsqueeze(1)
        reward_batch = torch.tensor(reward_batch, dtype=torch.float)
        next_state_batch = torch.tensor(next_state_batch, dtype=torch.float)
        done_batch = torch.tensor(np.float32(done_batch))
        q_values = self.policy_net(state_batch).gather(
            dim=1, index=action_batch).squeeze(1)
        next_state_values = self.policy_net(next_state_batch).max(
            1)[0].detach()
        expected_q_values = reward_batch + 0.99 * next_state_values * (
            1 - done_batch)
        loss = nn.MSELoss()(q_values, expected_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        for param in self.policy_net.parameters():  # clip防止梯度爆炸
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()
        self.loss_numpy = loss.detach().numpy()
        self.losses.append(self.loss_numpy)

    def update_meta(self):
        if self.batch_size > len(self.meta_memory):
            return
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.meta_memory.sample(
            self.batch_size)
        state_batch = torch.tensor(state_batch, dtype=torch.float)
        action_batch = torch.tensor(action_batch,
                                    dtype=torch.int64).unsqueeze(1)
        reward_batch = torch.tensor(reward_batch, dtype=torch.float)
        next_state_batch = torch.tensor(next_state_batch, dtype=torch.float)
        done_batch = torch.tensor(np.float32(done_batch))
        q_values = self.meta_policy_net(state_batch).gather(
            dim=1, index=action_batch).squeeze(1)
        next_state_values = self.meta_policy_net(next_state_batch).max(
            1)[0].detach()
        expected_q_values = reward_batch + 0.99 * next_state_values * (
            1 - done_batch)
        meta_loss = nn.MSELoss()(q_values, expected_q_values)
        self.meta_optimizer.zero_grad()
        meta_loss.backward()
        for param in self.meta_policy_net.parameters():  # clip防止梯度爆炸
            param.grad.data.clamp_(-1, 1)
        self.meta_optimizer.step()
        self.meta_loss_numpy = meta_loss.detach().numpy()
        self.meta_losses.append(self.meta_loss_numpy)

    def save(self, path):
        torch.save(self.policy_net.state_dict(),
                   path + 'policy_checkpoint.pth')
        torch.save(self.meta_policy_net.state_dict(),
                   path + 'meta_checkpoint.pth')

    def load(self, path):
        self.policy_net.load_state_dict(
            torch.load(path + 'policy_checkpoint.pth'))
        self.meta_policy_net.load_state_dict(
            torch.load(path + 'meta_checkpoint.pth'))
Exemplo n.º 4
0
class DQN:
    """Nature-style DQN agent with epsilon-greedy training and greedy evaluation.

    Args:
        state_dim: size of the state vector fed to the networks.
        action_dim: total number of discrete actions.
        cfg: config object; reads ``device``, ``gamma``, ``epsilon_start``,
            ``epsilon_end``, ``epsilon_decay``, ``batch_size``, ``hidden_dim``,
            ``lr`` and ``memory_capacity``.
    """

    def __init__(self, state_dim, action_dim, cfg):

        self.action_dim = action_dim  # total number of actions
        self.device = cfg.device  # device, e.g. cpu or gpu
        self.gamma = cfg.gamma  # reward discount factor
        # epsilon-greedy parameters
        self.sample_count = 0  # step counter driving the epsilon decay
        self.epsilon = 0
        self.epsilon_start = cfg.epsilon_start
        self.epsilon_end = cfg.epsilon_end
        self.epsilon_decay = cfg.epsilon_decay
        self.batch_size = cfg.batch_size
        self.policy_net = MLP(state_dim, action_dim,
                              hidden_dim=cfg.hidden_dim).to(self.device)
        self.target_net = MLP(state_dim, action_dim,
                              hidden_dim=cfg.hidden_dim).to(self.device)
        # The target network starts as an exact copy of the policy network.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()  # disables BatchNorm / Dropout training behaviour
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg.lr)
        self.loss = 0
        self.memory = ReplayBuffer(cfg.memory_capacity)

    def _greedy_action(self, net, state):
        """Return ``argmax_a net(state)[a]`` as a Python int, without tracking gradients."""
        with torch.no_grad():
            # Wrapping in a list adds a batch dimension of 1;
            # equivalent to torch.tensor(state).unsqueeze(0).
            state = torch.tensor([state], device=self.device, dtype=torch.float32)
            # tensor.max(1) returns (values, indices); [1] is the argmax, i.e. the action.
            return net(state).max(1)[1].item()

    def choose_action(self, state, train=True):
        '''Select an action for ``state``.

        With ``train=True`` the policy network is used epsilon-greedily and
        the decay counter advances. With ``train=False`` the target network
        (the one ``save``/``load`` persist) picks the greedy action.
        '''
        if not train:
            # Use self.device here: the original hard-coded 'cpu', which fails
            # when the networks are on a GPU.
            return self._greedy_action(self.target_net, state)
        self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
            math.exp(-1. * self.sample_count / self.epsilon_decay)
        self.sample_count += 1
        if random.random() > self.epsilon:
            return self._greedy_action(self.policy_net, state)
        return random.randrange(self.action_dim)

    def update(self):
        """Run one DQN gradient step on a sampled minibatch.

        Does nothing until the replay buffer holds ``batch_size`` transitions.
        The TD target is ``r + gamma * max_a Q_target(s', a) * (1 - done)``.
        """
        if len(self.memory) < self.batch_size:
            return
        # Sample transitions at random from the replay buffer.
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.memory.sample(
            self.batch_size)
        state_batch = torch.tensor(state_batch,
                                   device=self.device,
                                   dtype=torch.float)
        # (B, 1) action indices for gather, e.g. tensor([[1], ..., [0]])
        action_batch = torch.tensor(action_batch,
                                    device=self.device).unsqueeze(1)
        reward_batch = torch.tensor(reward_batch,
                                    device=self.device,
                                    dtype=torch.float)
        next_state_batch = torch.tensor(next_state_batch,
                                        device=self.device,
                                        dtype=torch.float)
        # Done flags as a (B,) float mask aligned with reward_batch. The
        # previous code unsqueezed this and indexed done_batch[0], which
        # broadcast the FIRST sample's flag across the whole batch.
        done_batch = torch.tensor(np.float32(done_batch), device=self.device)

        # Q(s_t, a_t) for the actions actually taken. gather picks one column
        # per row, e.g. [[1, 2], [3, 4]].gather(1, [[0], [1]]) == [[1], [4]].
        q_values = self.policy_net(state_batch).gather(dim=1, index=action_batch)
        with torch.no_grad():
            # max_a Q_target(s_{t+1}, a); the target is constant w.r.t. the optimiser.
            next_state_values = self.target_net(next_state_batch).max(1)[0]
            # Terminal transitions (done == 1) bootstrap nothing: the target is the reward.
            expected_q_values = reward_batch + self.gamma * \
                next_state_values * (1 - done_batch)
        self.loss = nn.MSELoss()(q_values,
                                 expected_q_values.unsqueeze(1))  # mean-squared TD error
        self.optimizer.zero_grad()  # drop gradients from the previous step
        self.loss.backward()
        # Clamp every gradient element to [-1, 1] to guard against exploding gradients.
        nn.utils.clip_grad_value_(self.policy_net.parameters(), 1)
        self.optimizer.step()

    def save(self, path):
        """Save the target network's weights to ``path + 'dqn_checkpoint.pth'``."""
        torch.save(self.target_net.state_dict(), path + 'dqn_checkpoint.pth')

    def load(self, path):
        """Load weights saved by ``save`` into the target network and sync the policy network to them."""
        self.target_net.load_state_dict(
            torch.load(path + 'dqn_checkpoint.pth', map_location=self.device))
        self.policy_net.load_state_dict(self.target_net.state_dict())
Exemplo n.º 5
0
class DoubleDQN:
    """Double DQN agent: the policy net selects next actions and the target net evaluates them.

    Args:
        state_dim: size of the state vector fed to the networks.
        action_dim: number of discrete actions.
        cfg: config object; reads ``gamma``, ``hidden_dim``, ``device``, ``lr``,
            ``epsilon_start``, ``epsilon_end``, ``epsilon_decay``,
            ``memory_capacity`` and ``batch_size``.
    """

    def __init__(self, state_dim, action_dim, cfg):
        self.action_dim = action_dim
        self.state_dim = state_dim
        self.gamma = cfg.gamma
        self.policy_net = MLP(state_dim, action_dim,
                              cfg.hidden_dim).to(cfg.device)
        self.target_net = MLP(state_dim, action_dim,
                              cfg.hidden_dim).to(cfg.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()  # disables BatchNorm / Dropout training behaviour
        self.optim = optim.Adam(self.policy_net.parameters(), lr=cfg.lr)
        self.device = cfg.device
        self.frame_idx = 0  # step counter driving the epsilon decay
        self.epsilon = lambda frame_idx: cfg.epsilon_end + (
            cfg.epsilon_start - cfg.epsilon_end) * math.exp(-1. * frame_idx /
                                                            cfg.epsilon_decay)
        self.memory = ReplayBuffer(cfg.memory_capacity)
        self.batch_size = cfg.batch_size
        self.loss = 0

    def choose_action(self, state):
        """Pick an epsilon-greedy action with the policy network and advance ``frame_idx``."""
        self.frame_idx += 1
        if random.random() > self.epsilon(self.frame_idx):
            with torch.no_grad():  # acting does not need an autograd graph
                # Only build the tensor when the network is actually queried.
                state = torch.tensor([state], device=self.device, dtype=torch.float32)
                q_values = self.policy_net(state)
                action = q_values.max(1)[1].item()
        else:
            action = random.randrange(self.action_dim)
        return action

    def update(self):
        """Run one Double DQN gradient step on a sampled minibatch.

        Does nothing until the replay buffer holds ``batch_size`` transitions.
        The TD target is
        ``r + gamma * Q_target(s', argmax_a Q_policy(s', a)) * (1 - done)``.
        """
        if len(self.memory) < self.batch_size:
            return
        # Sample a minibatch.
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.memory.sample(
            self.batch_size)

        # Convert to tensors on the target device.
        state_batch = torch.tensor(state_batch,
                                   device=self.device,
                                   dtype=torch.float32)
        action_batch = torch.tensor(action_batch,
                                    device=self.device).unsqueeze(1)
        # Explicit float32: float64 rewards would otherwise promote the target
        # to double and mismatch the float32 Q-values.
        reward_batch = torch.tensor(reward_batch, device=self.device,
                                    dtype=torch.float32)
        next_state_batch = torch.tensor(next_state_batch,
                                        device=self.device,
                                        dtype=torch.float32)
        # Float 0/1 mask. The old `~done_batch` was only correct for bool
        # tensors; int flags 0/1 turned into -1/-2.
        done_batch = torch.tensor(np.float32(done_batch), device=self.device)

        # Q(s_t, a_t) for the actions actually taken.
        q_values = self.policy_net(state_batch).gather(dim=1,
                                                       index=action_batch)

        with torch.no_grad():
            # The key Double DQN step: the policy net selects the next action ...
            next_action_batch = self.policy_net(next_state_batch).max(
                1)[1].unsqueeze(1)
            # ... and the target net provides its Q-value.
            next_q_values = self.target_net(next_state_batch).gather(
                dim=1, index=next_action_batch).squeeze(1)
            expected_q_values = reward_batch + self.gamma * next_q_values * (
                1 - done_batch)

        self.loss = nn.MSELoss()(q_values, expected_q_values.unsqueeze(1))
        self.optim.zero_grad()
        self.loss.backward()
        # Clamp every gradient element to [-1, 1] to guard against exploding gradients.
        nn.utils.clip_grad_value_(self.policy_net.parameters(), 1)
        self.optim.step()

    def save(self, path):
        """Save the target network's weights to ``path + 'DQN_CheckPoint.pth'``."""
        torch.save(self.target_net.state_dict(), path + 'DQN_CheckPoint.pth')

    def load(self, path):
        """Load weights saved by ``save`` into the target network and sync the policy network to them."""
        self.target_net.load_state_dict(
            torch.load(path + 'DQN_CheckPoint.pth', map_location=self.device))
        self.policy_net.load_state_dict(self.target_net.state_dict())
Exemplo n.º 6
0
class DQN:
    """Nature-style DQN agent with an epsilon-greedy policy network and a target network.

    Args:
        state_dim: size of the state vector fed to the networks.
        action_dim: total number of discrete actions.
        cfg: config object; reads ``device``, ``gamma``, ``epsilon_start``,
            ``epsilon_end``, ``epsilon_decay``, ``batch_size``, ``hidden_dim``,
            ``lr`` and ``memory_capacity``.
    """

    def __init__(self, state_dim, action_dim, cfg):

        self.action_dim = action_dim  # total number of actions
        self.device = cfg.device  # device, e.g. cpu or gpu
        self.gamma = cfg.gamma  # reward discount factor
        # epsilon-greedy parameters
        self.frame_idx = 0  # step counter driving the epsilon decay
        self.epsilon = lambda frame_idx: cfg.epsilon_end + \
            (cfg.epsilon_start - cfg.epsilon_end) * \
            math.exp(-1. * frame_idx / cfg.epsilon_decay)
        self.batch_size = cfg.batch_size
        self.policy_net = MLP(state_dim, action_dim, hidden_dim=cfg.hidden_dim).to(self.device)
        self.target_net = MLP(state_dim, action_dim, hidden_dim=cfg.hidden_dim).to(self.device)
        # The target network starts as an exact copy of the policy network.
        for target_param, param in zip(self.target_net.parameters(), self.policy_net.parameters()):
            target_param.data.copy_(param.data)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg.lr)
        self.memory = ReplayBuffer(cfg.memory_capacity)
        # Latest training loss, exposed like the other agents' ``loss`` attribute.
        self.loss = 0

    def choose_action(self, state):
        '''Pick an epsilon-greedy action and advance ``frame_idx``.'''
        self.frame_idx += 1
        if random.random() > self.epsilon(self.frame_idx):
            action = self.predict(state)
        else:
            action = random.randrange(self.action_dim)
        return action

    def predict(self, state):
        """Return the greedy action (argmax of the policy network's Q-values) for ``state``."""
        with torch.no_grad():
            state = torch.tensor([state], device=self.device, dtype=torch.float32)
            q_values = self.policy_net(state)
            action = q_values.max(1)[1].item()
        return action

    def update(self):
        """Run one DQN gradient step on a sampled minibatch and store the loss in ``self.loss``.

        Does nothing until the replay buffer holds ``batch_size`` transitions.
        The TD target is ``r + gamma * max_a Q_target(s', a) * (1 - done)``.
        """
        if len(self.memory) < self.batch_size:
            return
        # Sample transitions at random from the replay buffer.
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.memory.sample(
            self.batch_size)
        state_batch = torch.tensor(
            state_batch, device=self.device, dtype=torch.float)
        # (B, 1) action indices for gather, e.g. tensor([[1], ..., [0]])
        action_batch = torch.tensor(action_batch, device=self.device).unsqueeze(1)
        reward_batch = torch.tensor(
            reward_batch, device=self.device, dtype=torch.float)
        next_state_batch = torch.tensor(
            next_state_batch, device=self.device, dtype=torch.float)
        done_batch = torch.tensor(np.float32(done_batch), device=self.device)

        # Q(s_t, a_t) for the actions actually taken. gather picks one column
        # per row, e.g. [[1, 2], [3, 4]].gather(1, [[0], [1]]) == [[1], [4]].
        q_values = self.policy_net(state_batch).gather(dim=1, index=action_batch)
        with torch.no_grad():
            # max_a Q_target(s_{t+1}, a); the target is constant w.r.t. the optimiser.
            next_q_values = self.target_net(next_state_batch).max(1)[0]
            # Terminal transitions (done == 1) bootstrap nothing: the target is the reward.
            expected_q_values = reward_batch + \
                self.gamma * next_q_values * (1 - done_batch)
        self.loss = nn.MSELoss()(q_values, expected_q_values.unsqueeze(1))  # mean-squared TD error
        self.optimizer.zero_grad()  # drop gradients from the previous step
        self.loss.backward()
        self.optimizer.step()

    def save(self, path):
        """Save the target network's weights to ``path + 'dqn_checkpoint.pth'``."""
        torch.save(self.target_net.state_dict(), path+'dqn_checkpoint.pth')

    def load(self, path):
        """Load weights saved by ``save`` into the target network and copy them into the policy network."""
        self.target_net.load_state_dict(
            torch.load(path+'dqn_checkpoint.pth', map_location=self.device))
        for target_param, param in zip(self.target_net.parameters(), self.policy_net.parameters()):
            param.data.copy_(target_param.data)