Example #1
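All of the snippets on this page rely on the same imports and on two replay-memory helpers (SlidingMemory and PERMemory) plus several network classes that are not shown here. As a rough guide only, the shared imports and a minimal SlidingMemory matching the add/num/sample calls used below might look like this (the original project's implementations may well differ):

import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


class SlidingMemory:
    """Sketch of a plain FIFO replay buffer with the interface the agents below expect."""

    def __init__(self, mem_size):
        self.mem_size = mem_size
        self.buffer = []

    def add(self, pre_state, action, reward, next_state, if_end):
        # store one transition, dropping the oldest when the buffer is full
        self.buffer.append((pre_state, action, reward, next_state, if_end))
        if len(self.buffer) > self.mem_size:
            self.buffer.pop(0)

    def num(self):
        return len(self.buffer)

    def sample(self, batch_size):
        # uniform random minibatch of stored transitions
        return random.sample(self.buffer, batch_size)
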
class DQN():
    def __init__(self,
                 state_dim,
                 action_dim,
                 mem_size=10000,
                 train_batch_size=32,
                 gamma=0.99,
                 lr=1e-3,
                 tau=0.1,
                 if_dueling=False,
                 if_PER=False,
                 load_path=None):
        self.mem_size, self.train_batch_size = mem_size, train_batch_size
        self.gamma, self.lr = gamma, lr
        self.global_step = 0
        self.tau = tau
        self.state_dim, self.action_dim = state_dim, action_dim
        self.if_PER = if_PER
        self.replay_mem = PERMemory(mem_size) if if_PER else SlidingMemory(
            mem_size)
        self.policy_net = DQN_fc_network(state_dim, action_dim, 1)
        self.target_net = DQN_fc_network(state_dim, action_dim, 1)
        self.epsilon, self.min_eps = 0.9, 0.4

        if if_dueling:
            self.policy_net = DQN_dueling_network(state_dim, action_dim, 1)
            self.target_net = DQN_dueling_network(state_dim, action_dim, 1)

        # load saved weights only after the architecture is chosen, so the
        # dueling branch above cannot silently discard them
        if load_path is not None:
            self.policy_net.load_state_dict(torch.load(load_path))

        self.optimizer = optim.RMSprop(self.policy_net.parameters(), self.lr)
        self.hard_update(self.target_net, self.policy_net)

    def soft_update(self, target, source, tau):
        for target_param, param in zip(target.parameters(),
                                       source.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - tau) +
                                    param.data * tau)

    def hard_update(self, target, source):
        for target_param, param in zip(target.parameters(),
                                       source.parameters()):
            target_param.data.copy_(param.data)

    #  training process
    def train(self, pre_state, action, reward, next_state, if_end):

        self.replay_mem.add(pre_state, action, reward, next_state, if_end)

        if self.replay_mem.num() < self.mem_size:
            return

        # sample $self.train_batch_size$ samples from the replay memory, and use them to train
        if not self.if_PER:
            train_batch = self.replay_mem.sample(self.train_batch_size)
        else:
            train_batch, idx_batch, weight_batch = self.replay_mem.sample(
                self.train_batch_size)
            weight_batch = torch.tensor(weight_batch,
                                        dtype=torch.float).unsqueeze(1)

        # adjust dtype to suit the gym default dtype
        pre_state_batch = torch.tensor([x[0] for x in train_batch],
                                       dtype=torch.float)
        action_batch = torch.tensor([x[1] for x in train_batch],
                                    dtype=torch.long)  # dtype must be long for gather
        reward_batch = torch.tensor([x[2] for x in train_batch],
                                    dtype=torch.float).view(
                                        self.train_batch_size, 1)
        next_state_batch = torch.tensor([x[3] for x in train_batch],
                                        dtype=torch.float)
        if_end = [x[4] for x in train_batch]
        if_end = torch.tensor(np.array(if_end).astype(float),
                              dtype=torch.float).view(self.train_batch_size, 1)

        # use the target_Q_network to get the target_Q_value
        # torch.max(...)[0] gives the max values, torch.max(...)[1] gives the argmax indices

        # vanilla dqn
        #q_target_ = self.target_net(next_state_batch).max(1)[0].detach() # detach to not bother the gradient
        #q_target_ = q_target_.view(self.train_batch_size,1)

        # double dqn

        with torch.no_grad():
            next_best_action = self.policy_net(next_state_batch).max(
                1)[1].detach()
            q_target_ = self.target_net(next_state_batch).gather(
                1, next_best_action.unsqueeze(1))
            q_target_ = q_target_.view(self.train_batch_size, 1).detach()

        q_target = self.gamma * q_target_ * (1 - if_end) + reward_batch

        # unsqueeze to make gather happy
        q_pred = self.policy_net(pre_state_batch).gather(
            1, action_batch.unsqueeze(1))

        if self.if_PER:
            TD_error_batch = np.abs(q_target.numpy() - q_pred.detach().numpy())
            self.replay_mem.update(idx_batch, TD_error_batch)

        self.optimizer.zero_grad()

        loss = (q_pred - q_target)**2
        if self.if_PER:
            loss *= weight_batch

        loss = torch.mean(loss)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1)
        self.optimizer.step()

        # update target network
        self.soft_update(self.target_net, self.policy_net, self.tau)

        self.epsilon = max(self.epsilon * 0.99995, 0.22)

    # store the (pre_s, action, reward, next_state, if_end) tuples in the replay memory
    def perceive(self, pre_s, action, reward, next_state, if_end):
        self.replay_mem.append([pre_s, action, reward, next_state, if_end])
        if len(self.replay_mem) > self.mem_size:
            self.replay_mem.popleft()

    # give a state and action, return the action value
    def get_value(self, s, a):
        s = torch.tensor(s, dtype=torch.float)
        with torch.no_grad():
            val = self.policy_net(s.unsqueeze(0)).gather(
                1,
                torch.tensor(a, dtype=torch.long).unsqueeze(1)).view(1, 1)

        return val.item()

    def save_model(self, save_path='./model/dqn_params'):
        torch.save(self.policy_net.state_dict(), save_path)

    # use the policy net to choose the action with the highest Q value
    def action(self, s, epsilon_greedy=True):
        p = random.random()
        if epsilon_greedy and p <= self.epsilon:
            return random.randint(0, self.action_dim - 1)
        else:
            s = torch.tensor(s, dtype=torch.float).unsqueeze(0)

            with torch.no_grad():
                # torch.max(...)[0] gives the max values, torch.max(...)[1] gives the argmax indices
                action = self.policy_net(s).max(dim=1)[1].view(
                    1, 1)  # use view for later item
            return action.item(
            )  # use item() to get the vanilla number instead of a tensor

    # choose an action according to the epsilon-greedy method
    def e_action(self, s):
        p = random.random()
        if p <= self.epsilon:
            return random.randint(0, self.action_dim - 1)
        else:
            return self.action(s)
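A minimal usage sketch for the DQN class above, assuming a Gym-style environment with a discrete action space and the classic 4-tuple step API (the environment name and episode count are illustrative only):

import gym

env = gym.make('CartPole-v1')
agent = DQN(state_dim=env.observation_space.shape[0],
            action_dim=env.action_space.n)

for episode in range(500):
    state = env.reset()
    done = False
    while not done:
        action = agent.action(state)                           # epsilon-greedy action
        next_state, reward, done, info = env.step(action)
        agent.train(state, action, reward, next_state, done)   # store transition; learning starts once the buffer is full
        state = next_state
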
Example #2
class CAC():    
    ''' Continuous-action actor-critic (CAC) agent.

    Uses a Gaussian policy network, a state-value critic, soft-updated
    target networks and, optionally, prioritized experience replay.
    '''
    def __init__(self, state_dim, action_dim, mem_size = 10000, train_batch_size = 64, \
                 gamma = 0.99, actor_lr = 1e-4, critic_lr = 1e-4, \
                 action_low = -1.0, action_high = 1.0, tau = 0.1, \
                 sigma = 2, if_PER = True, save_path = '/record/cac'):
        
        self.mem_size, self.train_batch_size = mem_size, train_batch_size
        self.gamma, self.actor_lr, self.critic_lr = gamma, actor_lr, critic_lr
        self.global_step = 0
        self.tau, self.if_PER= tau, if_PER
        self.state_dim, self.action_dim = state_dim, action_dim
        self.replay_mem = PERMemory(mem_size) if if_PER else SlidingMemory(mem_size)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # self.device = 'cpu'
        self.action_low, self.action_high = action_low, action_high
        # actor networks with a learned sigma (the fixed-sigma CAC_a_fc_network variant is not used below)
        self.actor_policy_net = CAC_a_sigma_fc_network(state_dim, action_dim, action_low, action_high, sigma).to(self.device)
        self.actor_target_net = CAC_a_sigma_fc_network(state_dim, action_dim, action_low, action_high, sigma).to(self.device)
        self.critic_policy_net = AC_v_fc_network(state_dim).to(self.device)
        self.critic_target_net = AC_v_fc_network(state_dim).to(self.device)
        self.actor_optimizer = optim.Adam(self.actor_policy_net.parameters(), self.actor_lr)
        self.critic_optimizer = optim.Adam(self.critic_policy_net.parameters(), self.critic_lr)
        self.hard_update(self.actor_target_net, self.actor_policy_net)
        self.hard_update(self.critic_target_net, self.critic_policy_net)
        self.save_path = save_path  # save() below relies on this attribute
    
    
    
    def soft_update(self, target, source, tau):
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)

    def hard_update(self, target, source):
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(param.data)
    
    #  training process                          
    def train(self, pre_state, action, reward, next_state, if_end):
        
        self.replay_mem.add(pre_state, action, reward, next_state, if_end)
        
        if self.replay_mem.num() < self.mem_size:
            return
        
        # sample $self.train_batch_size$ samples from the replay memory, and use them to train
        if not self.if_PER:
            train_batch = self.replay_mem.sample(self.train_batch_size)
        else:
            train_batch, idx_batch, weight_batch = self.replay_mem.sample(self.train_batch_size)
            weight_batch = torch.tensor(weight_batch, dtype = torch.float, device = self.device).unsqueeze(1)
        
        # adjust dtype to suit the gym default dtype
        pre_state_batch = torch.tensor([x[0] for x in train_batch], dtype=torch.float, device = self.device) 
        action_batch = torch.tensor([x[1] for x in train_batch], dtype = torch.float, device = self.device) 
        # view to make later computation happy
        reward_batch = torch.tensor([x[2] for x in train_batch], dtype=torch.float, device = self.device).view(self.train_batch_size,1)
        next_state_batch = torch.tensor([x[3] for x in train_batch], dtype=torch.float, device = self.device)
        if_end = [x[4] for x in train_batch]
        if_end = torch.tensor(np.array(if_end).astype(float),device = self.device, dtype=torch.float).view(self.train_batch_size,1)
        
        
        # use the target_Q_network to get the target_Q_value
        with torch.no_grad():
            v_next_state = self.critic_target_net(next_state_batch).detach()
            v_target = self.gamma * v_next_state * (1 - if_end) + reward_batch

        v_pred = self.critic_policy_net(pre_state_batch)
        
        if self.if_PER:
            TD_error_batch = np.abs(v_target.cpu().numpy() - v_pred.cpu().detach().numpy())
            self.replay_mem.update(idx_batch, TD_error_batch)
        
        self.critic_optimizer.zero_grad()
        closs = (v_pred - v_target) ** 2 
        if self.if_PER:
            closs *= weight_batch
        closs = closs.mean()
        closs.backward()
        torch.nn.utils.clip_grad_norm_(self.critic_policy_net.parameters(),2)
        self.critic_optimizer.step()
        
        
        self.actor_optimizer.zero_grad()
        
        dist = self.actor_policy_net(pre_state_batch)
        log_action_prob = dist.log_prob(action_batch)
        log_action_prob = torch.sum(log_action_prob, dim = 1, keepdim = True)  # keep shape (batch, 1) so it broadcasts with TD_error below
        entropy = torch.mean(dist.entropy()) * 0.05
        
        with torch.no_grad(): 
            v_next_state = self.critic_policy_net(next_state_batch).detach()
            v_target = self.gamma * v_next_state * (1 - if_end) + reward_batch
            TD_error = v_target - self.critic_policy_net(pre_state_batch).detach()
            
        aloss = -log_action_prob * TD_error
        aloss = aloss.mean() - entropy
        aloss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor_policy_net.parameters(),2)
        self.actor_optimizer.step()
    
        # update target network
        self.soft_update(self.actor_target_net, self.actor_policy_net, self.tau)
        self.soft_update(self.critic_target_net, self.critic_policy_net, self.tau)
        self.global_step += 1
    
    # store the (pre_s, action, reward, next_state, if_end) tuples in the replay memory
    def perceive(self, pre_s, action, reward, next_state, if_end):
        self.replay_mem.append([pre_s, action, reward, next_state, if_end])
        if len(self.replay_mem) > self.mem_size:
            self.replay_mem.popleft()
        
    
    # use the policy net to choose the action with the highest Q value
    def action(self, s, sample = True): # use flag to suit other models' action interface
        s = torch.tensor(s, dtype=torch.float, device = self.device).unsqueeze(0)
        # print(s)
        with torch.no_grad():
            m = self.actor_policy_net(s)
            # print(m)
            a = m.sample().clamp(self.action_low, self.action_high) if sample else m.mean  # clamp on the tensor so this also works on GPU
            return a.cpu().numpy()[0]

    def save(self, save_path = None):
        path = save_path if save_path is not None else self.save_path
        torch.save(self.actor_policy_net.state_dict(), path + '_actor.txt' )
        torch.save(self.critic_policy_net.state_dict(), path + '_critic.txt')
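This class (and several below) defaults to if_PER=True, but PERMemory itself is not shown on this page. A rough proportional-prioritization sketch matching the add/num/sample/update calls used in these snippets, with linear-scan sampling instead of a sum tree and illustrative alpha/beta/eps defaults:

import numpy as np


class PERMemory:
    """Sketch of a proportional prioritized replay buffer; the original project's version may differ."""

    def __init__(self, mem_size, alpha=0.6, beta=0.4, eps=1e-6):
        self.mem_size, self.alpha, self.beta, self.eps = mem_size, alpha, beta, eps
        self.buffer, self.priorities = [], []

    def add(self, pre_state, action, reward, next_state, if_end):
        # new transitions get the current maximum priority so they are seen at least once
        max_p = max(self.priorities, default=1.0)
        self.buffer.append((pre_state, action, reward, next_state, if_end))
        self.priorities.append(max_p)
        if len(self.buffer) > self.mem_size:
            self.buffer.pop(0)
            self.priorities.pop(0)

    def num(self):
        return len(self.buffer)

    def sample(self, batch_size):
        # sample proportionally to priority^alpha and return importance-sampling weights
        probs = np.array(self.priorities) ** self.alpha
        probs /= probs.sum()
        idx = np.random.choice(len(self.buffer), batch_size, p=probs)
        weights = (len(self.buffer) * probs[idx]) ** (-self.beta)
        weights /= weights.max()
        return [self.buffer[i] for i in idx], idx, weights

    def update(self, idx_batch, td_errors):
        # refresh priorities with the latest absolute TD errors
        for i, err in zip(idx_batch, np.asarray(td_errors).flatten()):
            self.priorities[i] = abs(float(err)) + self.eps
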
    
        
    
        
Example #3
class DDPG():
    '''
    Deep Deterministic Policy Gradient (DDPG) agent with target networks and optional prioritized replay.
    '''
    def __init__(self,
                 state_dim,
                 action_dim,
                 mem_size,
                 train_batch_size,
                 gamma,
                 actor_lr,
                 critic_lr,
                 action_low,
                 action_high,
                 tau,
                 noise,
                 if_PER=True,
                 save_path='./record/ddpg'):
        self.mem_size, self.train_batch_size = mem_size, train_batch_size
        self.gamma, self.actor_lr, self.critic_lr = gamma, actor_lr, critic_lr
        self.global_step = 0
        self.tau, self.explore = tau, noise
        self.state_dim, self.action_dim = state_dim, action_dim
        self.action_high, self.action_low = action_high, action_low
        self.replay_mem = PERMemory(mem_size) if if_PER else SlidingMemory(
            mem_size)
        # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = 'cpu'
        self.if_PER = if_PER
        self.actor_policy_net = DDPG_actor_network(state_dim, action_dim,
                                                   action_low,
                                                   action_high).to(self.device)
        self.actor_target_net = DDPG_actor_network(state_dim, action_dim,
                                                   action_low,
                                                   action_high).to(self.device)
        self.critic_policy_net = DDPG_critic_network(state_dim, action_dim).to(
            self.device)
        self.critic_target_net = DDPG_critic_network(state_dim, action_dim).to(
            self.device)
        # self.critic_policy_net = NAF_network(state_dim, action_dim, action_low, action_high, self.device).to(self.device)
        # self.critic_target_net = NAF_network(state_dim, action_dim, action_low, action_high, self.device).to(self.device)
        self.critic_policy_net.apply(self._weight_init)
        self.actor_policy_net.apply(self._weight_init)
        self.actor_optimizer = optim.RMSprop(
            self.actor_policy_net.parameters(), self.actor_lr)
        self.critic_optimizer = optim.RMSprop(
            self.critic_policy_net.parameters(), self.critic_lr)
        self.hard_update(self.actor_target_net, self.actor_policy_net)
        self.hard_update(self.critic_target_net, self.critic_policy_net)
        self.save_path = save_path

    def _weight_init(self, m):
        if type(m) == nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)
            torch.nn.init.constant_(m.bias, 0.01)

    def soft_update(self, target, source, tau):
        for target_param, param in zip(target.parameters(),
                                       source.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - tau) +
                                    param.data * tau)

    def hard_update(self, target, source):
        for target_param, param in zip(target.parameters(),
                                       source.parameters()):
            target_param.data.copy_(param.data)

    #  training process
    def train(self, pre_state, action, reward, next_state, if_end):

        self.replay_mem.add(pre_state, action, reward, next_state, if_end)

        if self.replay_mem.num() < self.mem_size:
            return

        self.explore.decaynoise()

        # sample $self.train_batch_size$ samples from the replay memory, and use them to train
        if not self.if_PER:
            train_batch = self.replay_mem.sample(self.train_batch_size)
        else:
            train_batch, idx_batch, weight_batch = self.replay_mem.sample(
                self.train_batch_size)
            weight_batch = torch.tensor(weight_batch,
                                        dtype=torch.float,
                                        device=self.device).unsqueeze(1)

        # adjust dtype to suit the gym default dtype
        pre_state_batch = torch.tensor([x[0] for x in train_batch],
                                       dtype=torch.float,
                                       device=self.device)
        action_batch = torch.tensor([x[1] for x in train_batch],
                                    dtype=torch.float,
                                    device=self.device)
        # view to make later computation happy
        reward_batch = torch.tensor([x[2] for x in train_batch],
                                    dtype=torch.float,
                                    device=self.device).view(
                                        self.train_batch_size, 1)
        next_state_batch = torch.tensor([x[3] for x in train_batch],
                                        dtype=torch.float,
                                        device=self.device)
        if_end = [x[4] for x in train_batch]
        if_end = torch.tensor(np.array(if_end).astype(float),
                              device=self.device,
                              dtype=torch.float).view(self.train_batch_size, 1)

        # use the target_Q_network to get the target_Q_value
        with torch.no_grad():
            next_action_batch = self.actor_target_net(next_state_batch)
            #print(next_action_batch)
            q_target_ = self.critic_target_net(next_state_batch,
                                               next_action_batch)
            q_target = self.gamma * q_target_ * (1 - if_end) + reward_batch

        q_pred = self.critic_policy_net(pre_state_batch, action_batch)

        if self.if_PER:
            TD_error_batch = np.abs(q_target.cpu().numpy() -
                                    q_pred.detach().cpu().numpy())
            self.replay_mem.update(idx_batch, TD_error_batch)

        self.critic_optimizer.zero_grad()
        closs = (q_pred - q_target)**2
        if self.if_PER:
            closs *= weight_batch

        closs = torch.mean(closs)
        closs.backward()
        torch.nn.utils.clip_grad_norm_(self.critic_policy_net.parameters(), 2)
        self.critic_optimizer.step()

        self.actor_optimizer.zero_grad()
        aloss = -self.critic_policy_net(pre_state_batch,
                                        self.actor_policy_net(pre_state_batch))

        aloss = aloss.mean()
        # print('aloss is {0}'.format(aloss))
        aloss.backward()
        # for para in self.actor_policy_net.parameters():
        #     print('grad is {0}'.format(para.grad))
        # if torch.max(para.grad).numpy() == 0:
        #     raise('why all 0?')
        torch.nn.utils.clip_grad_norm_(self.actor_policy_net.parameters(), 2)
        self.actor_optimizer.step()

        # update target network
        self.soft_update(self.actor_target_net, self.actor_policy_net,
                         self.tau)
        self.soft_update(self.critic_target_net, self.critic_policy_net,
                         self.tau)
        self.global_step += 1

        if self.global_step > 0 and self.global_step % 10000 == 0:
            torch.save(self.actor_policy_net.state_dict(),
                       './record/ddpg_actor_param.txt')
            torch.save(self.critic_policy_net.state_dict(),
                       './record/ddpg_critic_param.txt')

    # store the (pre_s, action, reward, next_state, if_end) tuples in the replay memory
    def perceive(self, pre_s, action, reward, next_state, if_end):
        self.replay_mem.append([pre_s, action, reward, next_state, if_end])
        if len(self.replay_mem) > self.mem_size:
            self.replay_mem.popleft()

    # use the policy net to choose the action with the highest Q value
    def action(self, s, add_noise=True):
        cur_gradient = s[-self.action_dim:]
        s = torch.tensor(s, dtype=torch.float, device=self.device).unsqueeze(0)
        with torch.no_grad():
            action = self.actor_policy_net(s)

        var = np.exp(np.linalg.norm(cur_gradient)) + 0.03
        noise = self.explore.noise() if add_noise else 0.0
        # use item() to get the vanilla number instead of a tensor
        #return [np.clip(np.random.normal(action.item(), self.explore_rate), self.action_low, self.action_high)]
        return np.clip(action.cpu().numpy()[0] + noise, self.action_low,
                       self.action_high)

    def save(self, save_path=None):
        path = save_path if save_path is not None else self.save_path
        torch.save(self.actor_policy_net.state_dict(), path + '_actor.txt')
        torch.save(self.critic_policy_net.state_dict(), path + '_critic.txt')
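The DDPG constructor above (and the NAF variants further down) expects a noise object providing noise() and decaynoise() (one variant also calls decrease() and setnoise()), which is not defined on this page. A simple decaying Gaussian exploration helper with that interface might look like the sketch below; the scale and decay values are illustrative, and it would be passed as the noise argument when constructing the agent.

import numpy as np


class GaussianExploration:
    """Sketch of the exploration-noise helper these agents expect; not from the original project."""

    def __init__(self, sigma=0.2, decay=0.9999, min_sigma=0.01):
        self.sigma, self.decay, self.min_sigma = sigma, decay, min_sigma

    def noise(self):
        # scalar zero-mean Gaussian sample, broadcast onto the action by the caller
        return float(np.random.normal(0.0, self.sigma))

    def decaynoise(self):
        # multiplicative decay of the exploration scale
        self.sigma = max(self.sigma * self.decay, self.min_sigma)

    def decrease(self, amount):
        # additive decrease, used by the args-driven NAF variant
        self.sigma = max(self.sigma - amount, self.min_sigma)

    def setnoise(self, sigma):
        self.sigma = sigma
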
Example #4
class AC():
    def __init__(self,
                 state_dim,
                 action_dim,
                 mem_size,
                 train_batch_size,
                 gamma,
                 actor_lr,
                 critic_lr,
                 tau,
                 if_PER=True):
        self.mem_size, self.train_batch_size = mem_size, train_batch_size
        self.gamma, self.actor_lr, self.critic_lr = gamma, actor_lr, critic_lr
        self.global_step = 0
        self.tau, self.if_PER = tau, if_PER
        self.state_dim, self.action_dim = state_dim, action_dim
        self.replay_mem = PERMemory(mem_size) if if_PER else SlidingMemory(
            mem_size)
        #self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = 'cpu'
        self.cret = nn.MSELoss()
        self.actor_policy_net = AC_a_fc_network(state_dim,
                                                action_dim).to(self.device)
        self.actor_target_net = AC_a_fc_network(state_dim,
                                                action_dim).to(self.device)
        self.critic_policy_net = AC_v_fc_network(state_dim).to(self.device)
        self.critic_target_net = AC_v_fc_network(state_dim).to(self.device)
        self.actor_optimizer = optim.Adam(self.actor_policy_net.parameters(),
                                          self.actor_lr)
        self.critic_optimizer = optim.Adam(self.critic_policy_net.parameters(),
                                           self.critic_lr)
        self.hard_update(self.actor_target_net, self.actor_policy_net)
        self.hard_update(self.critic_target_net, self.critic_policy_net)

    def soft_update(self, target, source, tau):
        for target_param, param in zip(target.parameters(),
                                       source.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - tau) +
                                    param.data * tau)

    def hard_update(self, target, source):
        for target_param, param in zip(target.parameters(),
                                       source.parameters()):
            target_param.data.copy_(param.data)

    #  training process
    def train(self, pre_state, action, reward, next_state, if_end):

        self.replay_mem.add(pre_state, action, reward, next_state, if_end)

        if self.replay_mem.num() < self.mem_size:
            return

        # sample $self.train_batch_size$ samples from the replay memory, and use them to train
        if not self.if_PER:
            train_batch = self.replay_mem.sample(self.train_batch_size)
        else:
            train_batch, idx_batch, weight_batch = self.replay_mem.sample(
                self.train_batch_size)
            weight_batch = torch.tensor(weight_batch,
                                        dtype=torch.float).unsqueeze(1)

        # adjust dtype to suit the gym default dtype
        pre_state_batch = torch.tensor([x[0] for x in train_batch],
                                       dtype=torch.float,
                                       device=self.device)
        action_batch = torch.tensor([x[1] for x in train_batch],
                                    dtype=torch.long,
                                    device=self.device)
        # view to make later computation happy
        reward_batch = torch.tensor([x[2] for x in train_batch],
                                    dtype=torch.float,
                                    device=self.device).view(
                                        self.train_batch_size, 1)
        next_state_batch = torch.tensor([x[3] for x in train_batch],
                                        dtype=torch.float,
                                        device=self.device)
        if_end = [x[4] for x in train_batch]
        if_end = torch.tensor(np.array(if_end).astype(float),
                              device=self.device,
                              dtype=torch.float).view(self.train_batch_size, 1)

        # use the target_Q_network to get the target_Q_value
        with torch.no_grad():
            v_next_state = self.critic_target_net(next_state_batch).detach()
            v_target = self.gamma * v_next_state * (1 - if_end) + reward_batch

        v_pred = self.critic_policy_net(pre_state_batch)

        if self.if_PER:
            TD_error_batch = np.abs(v_target.numpy() - v_pred.detach().numpy())
            self.replay_mem.update(idx_batch, TD_error_batch)

        self.critic_optimizer.zero_grad()
        closs = (v_pred - v_target)**2
        if self.if_PER:
            closs *= weight_batch
        closs = closs.mean()
        closs.backward()
        torch.nn.utils.clip_grad_norm_(self.critic_policy_net.parameters(), 1)
        self.critic_optimizer.step()

        self.actor_optimizer.zero_grad()

        action_prob = self.actor_policy_net(pre_state_batch).gather(
            1, action_batch.unsqueeze(1))
        log_action_prob = torch.log(action_prob.clamp(min=1e-15))

        with torch.no_grad():
            v_next_state = self.critic_policy_net(next_state_batch).detach()
            v_target = self.gamma * v_next_state * (1 - if_end) + reward_batch
            TD_error = v_target - self.critic_policy_net(
                pre_state_batch).detach()

        aloss = -log_action_prob * TD_error
        aloss = aloss.mean()

        aloss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor_policy_net.parameters(), 1)
        self.actor_optimizer.step()

        # update target network
        self.soft_update(self.actor_target_net, self.actor_policy_net,
                         self.tau)
        self.soft_update(self.critic_target_net, self.critic_policy_net,
                         self.tau)
        self.global_step += 1

    # store the (pre_s, action, reward, next_state, if_end) tuples in the replay memory
    def perceive(self, pre_s, action, reward, next_state, if_end):
        self.replay_mem.append([pre_s, action, reward, next_state, if_end])
        if len(self.replay_mem) > self.mem_size:
            self.replay_mem.popleft()

    # use the policy net to choose the action with the highest Q value
    def action(self,
               s,
               sample=True):  # use flag to suit other models' action interface
        s = torch.tensor(s, dtype=torch.float, device=self.device).unsqueeze(0)
        with torch.no_grad():
            action_prob = self.actor_policy_net(s)
            return np.random.choice(self.action_dim, p=action_prob.numpy()[0])
class NAF():
    '''
    Normalized Advantage Function (NAF) agent (args-driven variant).
    '''
    def __init__(self, args, noise, flag=False, if_PER=False):
        self.args = args
        self.mem_size, self.train_batch_size = args.replay_size, args.batch_size
        self.gamma, self.lr = args.gamma, args.lr
        self.global_step = 0
        self.tau, self.explore = args.tau, noise
        self.state_dim, self.action_dim = args.state_dim, args.action_dim
        self.action_high, self.action_low = args.action_high, args.action_low
        self.if_PER = if_PER
        self.replay_mem = PERMemory(
            self.mem_size) if if_PER else SlidingMemory(self.mem_size)
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.policy_net = NAF_network(self.state_dim, self.action_dim,
                                      self.action_low, self.action_high,
                                      self.device).to(self.device)
        self.target_net = NAF_network(self.state_dim, self.action_dim,
                                      self.action_low, self.action_high,
                                      self.device).to(self.device)
        self.policy_net.apply(self._weight_init)
        if self.args.optimizer == 'adam':
            self.optimizer = optim.Adam(self.policy_net.parameters(), self.lr)
        elif self.args.optimizer == 'rmsprop':
            self.optimizer = optim.RMSprop(self.policy_net.parameters(),
                                           self.lr)
        else:
            print('Invalid Optimizer!')
            exit()
        self.hard_update(self.target_net, self.policy_net)

        self.flag = flag

    def _weight_init(self, m):
        if type(m) == nn.Linear:
            torch.nn.init.xavier_normal_(m.weight)
            torch.nn.init.constant_(m.bias, 0.01)

    def soft_update(self, target, source, tau):
        for target_param, param in zip(target.parameters(),
                                       source.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - tau) +
                                    param.data * tau)

    def hard_update(self, target, source):
        for target_param, param in zip(target.parameters(),
                                       source.parameters()):
            target_param.data.copy_(param.data)

    ###  training process
    def train(self, pre_state, action, reward, next_state, if_end):

        self.replay_mem.add(pre_state, action, reward, next_state, if_end)

        if self.replay_mem.num() == self.mem_size - 1:
            print('Replay Memory Filled, Now Start Training!')

        if self.replay_mem.num() < self.mem_size:
            return

        ### sample $self.train_batch_size$ samples from the replay memory, and use them to train
        if not self.if_PER:
            train_batch = self.replay_mem.sample(self.train_batch_size)
        else:
            train_batch, idx_batch, weight_batch = self.replay_mem.sample(
                self.train_batch_size)
            weight_batch = torch.tensor(weight_batch,
                                        dtype=torch.float,
                                        device=self.device).unsqueeze(1)

        pre_state_batch = torch.tensor([x[0] for x in train_batch],
                                       dtype=torch.float,
                                       device=self.device)
        action_batch = torch.tensor([x[1] for x in train_batch],
                                    dtype=torch.float,
                                    device=self.device)
        reward_batch = torch.tensor([x[2] for x in train_batch],
                                    dtype=torch.float,
                                    device=self.device).unsqueeze(
                                        1)  #.view(self.train_batch_size,1)
        next_state_batch = torch.tensor([x[3] for x in train_batch],
                                        dtype=torch.float,
                                        device=self.device)
        if_end = [x[4] for x in train_batch]
        if_end = torch.tensor(np.array(if_end).astype(float),
                              device=self.device,
                              dtype=torch.float).unsqueeze(
                                  1)  #.view(self.train_batch_size,1)

        ### use the target_Q_network to get the target_Q_value
        with torch.no_grad():
            q_target_, _ = self.target_net(next_state_batch)
            q_target = self.gamma * q_target_ * (1 - if_end) + reward_batch

        q_pred = self.policy_net(pre_state_batch, action_batch)

        if self.if_PER:
            TD_error_batch = np.abs(q_target.cpu().numpy() -
                                    q_pred.cpu().detach().numpy())
            self.replay_mem.update(idx_batch, TD_error_batch)

        self.optimizer.zero_grad()
        loss = (q_pred - q_target)**2
        if self.if_PER:
            loss *= weight_batch

        loss = torch.mean(loss)
        if self.flag:
            loss -= q_pred.mean()  # to test one of my ideas
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1)
        self.optimizer.step()

        ### update target network
        self.soft_update(self.target_net, self.policy_net, self.tau)
        # self.hard_update(self.target_net, self.policy_net)
        self.global_step += 1

        ### decrease explore ratio
        if self.global_step % self.args.explore_decrease_every == 0:
            self.explore.decrease(self.args.explore_decrease)

    ### store the (pre_s, action, reward, next_state, if_end) tuples in the replay memory
    def perceive(self, pre_s, action, reward, next_state, if_end):
        self.replay_mem.append([pre_s, action, reward, next_state, if_end])
        if len(self.replay_mem) > self.mem_size:
            self.replay_mem.popleft()

    ### give a state and action, return the action value
    def get_value(self, s, a):
        s = torch.tensor(s, dtype=torch.float, device=self.device).unsqueeze(0)
        a = torch.tensor(a, dtype=torch.float, device=self.device).unsqueeze(0)
        with torch.no_grad():
            # NAF_network(state, action) returns Q(s, a), as in train() above
            val = self.policy_net(s, a)

        return val.item()

    ### use the policy net to choose the action with the highest Q value
    def action(self, s, add_noise=True):
        s = torch.tensor(s, dtype=torch.float, device=self.device)
        with torch.no_grad():
            _, action = self.policy_net(s)

        action = action.cpu().numpy()
        if add_noise:
            noise = [[self.explore.noise() for j in range(action.shape[1])]
                     for i in range(action.shape[0])]
        else:
            noise = [[0 for j in range(action.shape[1])]
                     for i in range(action.shape[0])]

        noise = np.array(noise)
        action += noise
        action = np.exp(action) / np.sum(np.exp(action), axis=1).reshape(-1, 1)
        return action

    def save(self, save_path=None):
        torch.save(self.policy_net.state_dict(), save_path + 'nafnet.txt')

    def load(self, load_path=None):
        self.policy_net.load_state_dict(torch.load(load_path + 'nafnet.txt'))
        self.hard_update(self.target_net, self.policy_net)

    def set_explore(self, error):
        explore_rate = 0.5 * np.log(error) / np.log(10)
        if self.replay_mem.num() >= self.mem_size:
            print('reset explore rate as: ', explore_rate)
            self.explore.setnoise(explore_rate)
Example #6
class NAF():
    def __init__(self,
                 state_dim,
                 action_dim,
                 mem_size,
                 train_batch_size,
                 gamma,
                 lr,
                 action_low,
                 action_high,
                 tau,
                 noise,
                 flag,
                 if_PER=True):
        self.mem_size, self.train_batch_size = mem_size, train_batch_size
        self.gamma, self.lr = gamma, lr
        self.global_step = 0
        self.tau, self.explore = tau, noise
        self.state_dim, self.action_dim = state_dim, action_dim
        self.action_high, self.action_low = action_high, action_low
        self.if_PER = if_PER
        self.replay_mem = PERMemory(mem_size) if if_PER else SlidingMemory(
            mem_size)
        #self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = 'cpu'
        self.policy_net = NAF_network(state_dim, action_dim, action_low,
                                      action_high).to(self.device)
        self.target_net = NAF_network(state_dim, action_dim, action_low,
                                      action_high).to(self.device)
        self.optimizer = optim.Adam(self.policy_net.parameters(), self.lr)
        self.hard_update(self.target_net, self.policy_net)

        self.flag = flag

    def soft_update(self, target, source, tau):
        for target_param, param in zip(target.parameters(),
                                       source.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - tau) +
                                    param.data * tau)

    def hard_update(self, target, source):
        for target_param, param in zip(target.parameters(),
                                       source.parameters()):
            target_param.data.copy_(param.data)

    #  training process
    def train(self, pre_state, action, reward, next_state, if_end):

        self.replay_mem.add(pre_state, action, reward, next_state, if_end)

        if self.replay_mem.num() < self.mem_size:
            return

        self.explore.decaynoise()

        # sample $self.train_batch_size$ samples from the replay memory, and use them to train
        if not self.if_PER:
            train_batch = self.replay_mem.sample(self.train_batch_size)
        else:
            train_batch, idx_batch, weight_batch = self.replay_mem.sample(
                self.train_batch_size)
            weight_batch = torch.tensor(weight_batch,
                                        dtype=torch.float).unsqueeze(1)

        # adjust dtype to suit the gym default dtype
        pre_state_batch = torch.tensor([x[0] for x in train_batch],
                                       dtype=torch.float,
                                       device=self.device)
        action_batch = torch.tensor([x[1] for x in train_batch],
                                    dtype=torch.float,
                                    device=self.device)
        # view to make later computation happy
        reward_batch = torch.tensor([x[2] for x in train_batch],
                                    dtype=torch.float,
                                    device=self.device).view(
                                        self.train_batch_size, 1)
        next_state_batch = torch.tensor([x[3] for x in train_batch],
                                        dtype=torch.float,
                                        device=self.device)
        if_end = [x[4] for x in train_batch]
        if_end = torch.tensor(np.array(if_end).astype(float),
                              device=self.device,
                              dtype=torch.float).view(self.train_batch_size, 1)

        # use the target_Q_network to get the target_Q_value
        with torch.no_grad():
            q_target_, _ = self.target_net(next_state_batch)
            q_target = self.gamma * q_target_ * (1 - if_end) + reward_batch

        q_pred = self.policy_net(pre_state_batch, action_batch)

        if self.if_PER:
            TD_error_batch = np.abs(q_target.numpy() - q_pred.detach().numpy())
            self.replay_mem.update(idx_batch, TD_error_batch)

        self.optimizer.zero_grad()
        loss = (q_pred - q_target)**2
        if self.if_PER:
            loss *= weight_batch

        loss = torch.mean(loss)
        if self.flag:
            loss -= q_pred.mean()  # to test one of my ideas
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1)
        self.optimizer.step()

        # update target network
        self.soft_update(self.target_net, self.policy_net, self.tau)
        self.global_step += 1

    # store the (pre_s, action, reward, next_state, if_end) tuples in the replay memory
    def perceive(self, pre_s, action, reward, next_state, if_end):
        self.replay_mem.append([pre_s, action, reward, next_state, if_end])
        if len(self.replay_mem) > self.mem_size:
            self.replay_mem.popleft()

    # give a state and action, return the action value
    def get_value(self, s, a):
        s = torch.tensor(s, dtype=torch.float, device=self.device).unsqueeze(0)
        a = torch.tensor(a, dtype=torch.float, device=self.device).unsqueeze(0)
        with torch.no_grad():
            # NAF_network(state, action) returns Q(s, a), as in train() above
            val = self.policy_net(s, a)

        return val.item()

    # use the policy net to choose the action with the highest Q value
    def action(self, s, add_noise=True):
        s = torch.tensor(s, dtype=torch.float, device=self.device).unsqueeze(0)
        with torch.no_grad():
            _, action = self.policy_net(s)

        noise = self.explore.noise() if add_noise else 0.0
        # use item() to get the vanilla number instead of a tensor
        #return [np.clip(np.random.normal(action.item(), self.explore_rate), self.action_low, self.action_high)]
        #print(action.numpy()[0])
        return np.clip(action.numpy()[0] + noise, self.action_low,
                       self.action_high)
Example #7
class DQN():    
    '''
    Deep Q-Network (DQN) agent (args-driven variant with double-DQN targets and optional dueling heads).
    '''
    def __init__(self, args, if_dueling = True, if_PER = False):
        self.args = args
        self.mem_size, self.train_batch_size = args.replay_size, args.batch_size
        self.gamma, self.lr = args.gamma, args.lr
        self.global_step = 0
        self.tau = args.tau
        self.state_dim, self.action_dim = args.state_dim, args.action_dim
        self.if_PER = if_PER
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.replay_mem = PERMemory(self.mem_size) if if_PER else SlidingMemory(self.mem_size)
        self.policy_net = DQN_fc_network(self.state_dim, self.action_dim, hidden_layers=1).to(self.device)
        self.target_net = DQN_fc_network(self.state_dim, self.action_dim, hidden_layers=1).to(self.device)
        self.epsilon = 1.0
        
        if if_dueling:
            self.policy_net = DQN_dueling_network(self.state_dim, self.action_dim, hidden_layers= 1).to(self.device)
            self.target_net = DQN_dueling_network(self.state_dim, self.action_dim, hidden_layers= 1).to(self.device)
        
        if args.formulation == 'FCONV':
            self.policy_net = DQN_FCONV_network(self.args.window_size, self.action_dim).to(self.device)
            self.target_net = DQN_FCONV_network(self.args.window_size, self.action_dim).to(self.device)

        self.policy_net.apply(self._weight_init)
        self.hard_update(self.target_net, self.policy_net)
        if args.optimizer == 'adam':
            self.optimizer = optim.Adam(self.policy_net.parameters(), self.lr)
        elif args.optimizer == 'rmsprop':
            self.optimizer = optim.RMSprop(self.policy_net.parameters(), self.lr)
        else:
            print('Error: Invalid Optimizer')
            exit()
        self.hard_update(self.target_net, self.policy_net)

    def _weight_init(self,m):
        if type(m) == nn.Linear:
            torch.nn.init.xavier_normal_(m.weight)
            torch.nn.init.constant_(m.bias, 0.01)
           
    def soft_update(self, target, source, tau):
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)

    def hard_update(self, target, source):
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(param.data)        

    #  training process                          
    def train(self, pre_state, action, reward, next_state, if_end):
        
        self.replay_mem.add(pre_state, action, reward, next_state, if_end)
        
        if self.replay_mem.num() == self.train_batch_size - 1:
            print('Collected enough samples, now begin training!')

        if self.replay_mem.num() < self.train_batch_size:
            return
        
        # sample $self.train_batch_size$ samples from the replay memory, and use them to train
        if not self.if_PER:
            train_batch = self.replay_mem.sample(self.train_batch_size)
        else:
            train_batch, idx_batch, weight_batch = self.replay_mem.sample(self.train_batch_size)
            weight_batch = torch.tensor(weight_batch, dtype = torch.float, device = self.device).unsqueeze(1)
            
        pre_state_batch = torch.tensor([x[0] for x in train_batch], dtype=torch.float, device = self.device).squeeze()
        # if self.args.debug:
        #     print('before squeeze, pre_state_batch shape is: ', pre_state_batch.shape)
        #     print('after squeeze, pre_state_batch_shape is: ', pre_state_batch.squeeze().shape)
        action_batch = torch.tensor([x[1] for x in train_batch], dtype = torch.long, device = self.device).unsqueeze(1) # dtype must be long for gather
        reward_batch = torch.tensor([x[2] for x in train_batch], dtype=torch.float, device = self.device).unsqueeze(1)
        next_state_batch = torch.tensor([x[3] for x in train_batch], dtype=torch.float, device = self.device).squeeze()
        if_end = [x[4] for x in train_batch]
        if_end = torch.tensor(np.array(if_end).astype(float), dtype=torch.float, device = self.device)
        
        # use the target_Q_network to get the target_Q_value
        # torch.max[0] gives the max value, torch.max[1] gives the argmax index
        
        # vanilla dqn
        #q_target_ = self.target_net(next_state_batch).max(1)[0].detach() # detach to not bother the gradient
        #q_target_ = q_target_.view(self.train_batch_size,1)
        
        ### double dqn
        with torch.no_grad():
            next_best_action = self.policy_net(next_state_batch).max(1)[1].detach().unsqueeze(1)
            # if self.args.debug:
            #     print('Next State Batch Shape Is: ', next_state_batch.shape)
            #     print('Next Best Action Shape Is: ', next_best_action.shape)
            q_target_ = self.target_net(next_state_batch).gather(1, next_best_action)
            # if self.args.debug:
            #     print('q_target_ Shape Is: ', q_target_.shape)
            
        if_end = if_end.view_as(q_target_)
        q_target = self.gamma * q_target_ * ( 1 - if_end) + reward_batch
        q_pred = self.policy_net(pre_state_batch).gather(1, action_batch) 
        
        if self.if_PER:
            TD_error_batch = np.abs(q_target.cpu().numpy() - q_pred.detach().cpu().numpy())
            self.replay_mem.update(idx_batch, TD_error_batch)
        
        self.optimizer.zero_grad()
        
        loss = (q_pred - q_target) ** 2 
        if self.if_PER:
            loss *= weight_batch
            
        loss = torch.mean(loss)    
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1)
        self.optimizer.step()
    
        ### soft update target network
        # self.soft_update(self.target_net, self.policy_net, self.tau)
        
        ### decrease exploration rate
        self.global_step += 1
        self.epsilon *= self.args.explore_decay
        # if self.global_step % self.args.explore_decrease_every == 0:
            # self.epsilon = max(self.args.explore_final, self.epsilon - self.args.explore_decrease)
            

        ### hard update target network
        if self.global_step % self.args.update_every == 0:
            self.hard_update(self.target_net, self.policy_net)

    # store the (pre_s, action, reward, next_state, if_end) tuples in the replay memory
    def perceive(self, pre_s, action, reward, next_state, if_end):
        self.replay_mem.append([pre_s, action, reward, next_state, if_end])
        if len(self.replay_mem) > self.mem_size:
            self.replay_mem.popleft()
            
    # give a state and action, return the action value
    def get_value(self, s, a):
        s = torch.tensor(s, dtype=torch.float, device=self.device)
        a = torch.tensor(a, dtype=torch.long, device=self.device).unsqueeze(0)
        with torch.no_grad():
            val = self.policy_net(s).gather(1, a).cpu().numpy()
            
        return val
    
    def save(self, save_path):
        torch.save(self.policy_net.state_dict(), save_path + '_DQN.txt')
        
        
    # use the policy net to choose the action with the highest Q value
    def action(self, s, epsilon_greedy = True):
        s = torch.tensor(s, dtype=torch.float, device = self.device) 
        p = random.random() 
        if epsilon_greedy and p <= self.epsilon:
            if self.args.formulation == 'MLP':
                return [random.randint(0, self.action_dim - 1) for i in range(s.shape[0])]
            else:
                action = [random.randint(0, self.action_dim - 1) for i in range(s.shape[2] - self.args.window_size * 2)]
                # if self.args.debug:
                #     print('In Action, random action length is: ', len(action))
                return action
        else:
            with torch.no_grad():
                # max(dim=1) returns (max values, argmax indices); keep the argmax indices
                _, action = self.policy_net(s).max(dim=1)
                action = action.cpu().numpy()
            
            if self.args.formulation == 'FCONV':
                action = action[0]

            # if self.args.debug:
            #     print('In Action, action.shape is: ', action.shape)
            return action 
    
    # choose an action according to the epsilon-greedy method
    # def e_action(self, s):
    #     p = random.random()
    #     if p <= self.epsilon:
    #         return random.randint(0, self.action_dim - 1)
    #     else:
    #         return self.action(s)

    def load(self, load_path):
        self.policy_net.load_state_dict(torch.load(load_path + '_DQN.txt'))
        self.hard_update(self.target_net, self.policy_net)

    def set_explore(self, error):
        explore_rate = 0.1 * np.log(error) / np.log(10)
        if self.replay_mem.num() >= self.mem_size:
            print('reset explore rate as: ', explore_rate)
            self.epsilon = explore_rate
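Unlike the earlier examples, this DQN variant is configured entirely through an args object. A hypothetical configuration showing the fields the class actually reads (names taken from the attribute accesses above; values are purely illustrative placeholders):

from types import SimpleNamespace

args = SimpleNamespace(
    replay_size=10000,      # replay memory capacity
    batch_size=32,          # minibatch size
    gamma=0.99,             # discount factor
    lr=1e-3,                # optimizer learning rate
    tau=0.1,                # soft-update coefficient (unused while hard updates are active)
    state_dim=4,
    action_dim=2,
    optimizer='adam',       # 'adam' or 'rmsprop'
    formulation='MLP',      # 'MLP' or 'FCONV'
    window_size=5,          # only read when formulation == 'FCONV'
    explore_decay=0.9999,   # per-step epsilon decay
    update_every=1000,      # hard-update interval for the target network
)

agent = DQN(args, if_dueling=True, if_PER=False)
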
Example #8
class NAF():    
    '''
    Normalized Advantage Function (NAF) agent.
    '''
    def __init__(self, state_dim, action_dim, mem_size, train_batch_size, gamma, lr,
                 action_low, action_high, tau, noise, flag = False, if_PER = False, 
                 save_path = './record/NAF'):
        self.mem_size, self.train_batch_size = mem_size, train_batch_size
        self.gamma, self.lr = gamma, lr
        self.global_step = 0
        self.tau, self.explore = tau, noise
        self.state_dim, self.action_dim = state_dim, action_dim
        self.action_high, self.action_low = action_high, action_low
        self.if_PER = if_PER
        self.replay_mem = PERMemory(mem_size) if if_PER else SlidingMemory(mem_size)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = 'cpu'
        self.policy_net = NAF_network(state_dim, action_dim, action_low, action_high, self.device).to(self.device)
        self.target_net = NAF_network(state_dim, action_dim,action_low, action_high, self.device).to(self.device)
        self.policy_net.apply(self._weight_init)
        self.optimizer = optim.RMSprop(self.policy_net.parameters(), self.lr)
        self.hard_update(self.target_net, self.policy_net)
        self.save_path = save_path  # save() below relies on this attribute

        self.flag = flag
    
    def _weight_init(self,m):
        if type(m) == nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)
            torch.nn.init.constant_(m.bias, 0.)
    
    def soft_update(self, target, source, tau):
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)

    def hard_update(self, target, source):
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(param.data)
    
    #  training process                          
    def train(self, pre_state, action, reward, next_state, if_end):
        
        self.replay_mem.add(pre_state, action, reward, next_state, if_end)
        
        if self.replay_mem.num() < self.mem_size:
            return
        
        self.explore.decaynoise()
        
        # print('training')
        
        # sample $self.train_batch_size$ samples from the replay memory, and use them to train
        if not self.if_PER:
            train_batch = self.replay_mem.sample(self.train_batch_size)
        else:
            train_batch, idx_batch, weight_batch = self.replay_mem.sample(self.train_batch_size)
            weight_batch = torch.tensor(weight_batch, dtype = torch.float).unsqueeze(1)
        
        # adjust dtype to suit the gym default dtype
        pre_state_batch = torch.tensor([x[0] for x in train_batch], dtype=torch.float, device = self.device) 
        action_batch = torch.tensor([x[1] for x in train_batch], dtype = torch.float, device = self.device) 
        # view to make later computation happy
        reward_batch = torch.tensor([x[2] for x in train_batch], dtype=torch.float, device = self.device).view(self.train_batch_size,1)
        next_state_batch = torch.tensor([x[3] for x in train_batch], dtype=torch.float, device = self.device)
        if_end = [x[4] for x in train_batch]
        if_end = torch.tensor(np.array(if_end).astype(float),device = self.device, dtype=torch.float).view(self.train_batch_size,1)
        
        # use the target_Q_network to get the target_Q_value
        with torch.no_grad():
            q_target_, _ = self.target_net(next_state_batch)
            q_target = self.gamma * q_target_ * (1 - if_end) + reward_batch


        q_pred = self.policy_net(pre_state_batch, action_batch)
        
        if self.if_PER:
            TD_error_batch = np.abs(q_target.cpu().numpy() - q_pred.cpu().detach().numpy())
            self.replay_mem.update(idx_batch, TD_error_batch)
        
        self.optimizer.zero_grad()
        #print('q_pred is {0} and q_target is {1}'.format(q_pred, q_target))
        loss = (q_pred - q_target) ** 2 
        if self.if_PER:
            loss *= weight_batch
            
        loss = torch.mean(loss)
        if self.flag:
            loss -= q_pred.mean() # to test one of my ideas
        loss.backward()
        # loss = torch.min(loss, torch.tensor(1000, dtype = torch.float))
        # print('loss is {0}'.format(loss))
        # # for para in self.policy_net.parameters():
        #     print('param is: \n {0} \n gradient is: \n {1}'.format(para, para.grad))
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 2)
        self.optimizer.step()
    
        # update target network
        self.soft_update(self.target_net, self.policy_net, self.tau)
        self.global_step += 1
    
    # store the (pre_s, action, reward, next_state, if_end) tuples in the replay memory
    def perceive(self, pre_s, action, reward, next_state, if_end):
        self.replay_mem.append([pre_s, action, reward, next_state, if_end])
        if len(self.replay_mem) > self.mem_size:
            self.replay_mem.popleft()
            
            
    # give a state and action, return the action value
    def get_value(self, s, a):
        s = torch.tensor(s, dtype=torch.float, device = self.device).unsqueeze(0)
        a = torch.tensor(a, dtype=torch.float, device = self.device).unsqueeze(0)
        with torch.no_grad():
            # NAF_network(state, action) returns Q(s, a), as in train() above
            val = self.policy_net(s, a)

        return val.item()
        
    
    # use the policy net to choose the action with the highest Q value
    def action(self, s, add_noise=True):
        cur_gradient = s[-self.action_dim:]
        s = torch.tensor(s, dtype=torch.float, device=self.device).unsqueeze(0)
        with torch.no_grad():
            _, action = self.policy_net(s)

        var = np.exp(np.linalg.norm(cur_gradient)) + 0.03  # note: computed but currently unused
        noise = self.explore.noise() if add_noise else 0.0
        # convert the (1, action_dim) tensor to a plain array, add exploration noise,
        # and clip back into the valid action range
        return np.clip(action.cpu().numpy()[0] + noise, self.action_low, self.action_high)

    def save(self, save_path=None):
        path = save_path if save_path is not None else self.save_path
        torch.save(self.policy_net.state_dict(), path + '_critic.txt')


class AC():
    """
    Actor-critic agent: a policy (actor) network trained with TD-error advantages
    and a state-value (critic) network trained by one-step TD learning.
    """
    def __init__(self, args, if_PER=False):
        self.args = args
        self.mem_size, self.train_batch_size = args.replay_size, args.batch_size
        self.gamma = args.gamma
        self.actor_lr = args.a_lr
        self.critic_lr = args.c_lr
        self.global_step = 0
        self.tau = args.tau
        self.if_PER = if_PER
        self.state_dim, self.action_dim = args.state_dim, args.action_dim
        self.replay_mem = PERMemory(self.mem_size) if if_PER else SlidingMemory(self.mem_size)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # self.device = 'cpu'
        self.cret = nn.MSELoss()
        self.actor_policy_net = AC_a_fc_network(self.state_dim, self.action_dim).to(self.device)
        self.actor_target_net = AC_a_fc_network(self.state_dim, self.action_dim).to(self.device)
        self.critic_policy_net = AC_v_fc_network(self.state_dim).to(self.device)
        self.critic_target_net = AC_v_fc_network(self.state_dim).to(self.device)
        self.critic_policy_net.apply(self._weight_init)
        self.actor_policy_net.apply(self._weight_init)
        self.actor_optimizer = optim.Adam(self.actor_policy_net.parameters(), self.actor_lr)
        self.critic_optimizer = optim.Adam(self.critic_policy_net.parameters(), self.critic_lr)
        self.hard_update(self.actor_target_net, self.actor_policy_net)
        self.hard_update(self.critic_target_net, self.critic_policy_net)
        self.save_path = './record/'

    def _weight_init(self, m):
        # Xavier-normal weights and a small constant bias for every linear layer
        if type(m) == nn.Linear:
            torch.nn.init.xavier_normal_(m.weight)
            torch.nn.init.constant_(m.bias, 0.01)
    
    
    def soft_update(self, target, source, tau):
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)

    def hard_update(self, target, source):
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(param.data)
    
    #  training process                          
    def train(self, pre_state, action, reward, next_state, if_end):
        
        self.replay_mem.add(pre_state, action, reward, next_state, if_end)
        
        if self.replay_mem.num() == self.mem_size - 1:
            print('replay memory is almost full, training will begin!')

        if self.replay_mem.num() < self.mem_size:
            return      

        # sample self.train_batch_size transitions from the replay memory and train on them
        if not self.if_PER:
            train_batch = self.replay_mem.sample(self.train_batch_size)
        else:
            train_batch, idx_batch, weight_batch = self.replay_mem.sample(self.train_batch_size)
            weight_batch = torch.tensor(weight_batch, dtype=torch.float, device=self.device).unsqueeze(1)
        
        # convert to tensors with the dtypes/devices the networks expect
        pre_state_batch = torch.tensor([x[0] for x in train_batch], dtype=torch.float, device=self.device).squeeze()
        # dtype=long and shape (batch_size, 1) so action_batch can index the action probabilities via gather()
        action_batch = torch.tensor([x[1] for x in train_batch], dtype=torch.long, device=self.device).unsqueeze(1)
        # view as (batch_size, 1) so rewards broadcast against the network outputs
        reward_batch = torch.tensor([x[2] for x in train_batch], dtype=torch.float, device=self.device).view(self.train_batch_size, 1)
        next_state_batch = torch.tensor([x[3] for x in train_batch], dtype=torch.float, device=self.device).squeeze()
        if_end = [x[4] for x in train_batch]
        if_end = torch.tensor(np.array(if_end).astype(float), dtype=torch.float, device=self.device).view(self.train_batch_size, 1)
        
        
        # use the target critic to compute the TD target for the state value
        with torch.no_grad():
            v_next_state = self.critic_target_net(next_state_batch).detach()
            v_target = self.gamma * v_next_state * (1 - if_end) + reward_batch
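            # i.e. y = r + gamma * (1 - done) * V_target(s'), the one-step value target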

        v_pred = self.critic_policy_net(pre_state_batch)
        
        if self.if_PER:
            TD_error_batch = np.abs(v_target.cpu().numpy() - v_pred.detach().cpu().numpy())
            self.replay_mem.update(idx_batch, TD_error_batch)
        
        self.critic_optimizer.zero_grad()
        closs = (v_pred - v_target) ** 2 
        if self.if_PER:
            closs *= weight_batch
        closs = closs.mean()
        closs.backward()
        torch.nn.utils.clip_grad_norm_(self.critic_policy_net.parameters(),1)
        self.critic_optimizer.step()
        
        
        self.actor_optimizer.zero_grad()
        action_prob = self.actor_policy_net(pre_state_batch)
        # gather() picks out pi(a_t | s_t) for the action actually taken in each transition
        action_prob = action_prob.gather(1, action_batch)
        # clamp before log to avoid log(0) = -inf when a probability collapses to zero
        log_action_prob = torch.log(action_prob.clamp(min=1e-10))
   
        with torch.no_grad(): 
            v_next_state = self.critic_policy_net(next_state_batch).detach()
            v_target = self.gamma * v_next_state * (1 - if_end) + reward_batch
            TD_error = v_target - self.critic_policy_net(pre_state_batch).detach()
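        # the TD error delta = r + gamma * V(s') - V(s) acts as a low-variance (if biased)
        # advantage estimate, so the actor loss below is the policy-gradient term
        # -log pi(a|s) * delta averaged over the batch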
        
        aloss = - log_action_prob * TD_error
        aloss = aloss.mean()
 
        aloss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor_policy_net.parameters(),1)
        self.actor_optimizer.step()
    
        # update target network
        self.soft_update(self.actor_target_net, self.actor_policy_net, self.tau)
        self.soft_update(self.critic_target_net, self.critic_policy_net, self.tau)
        self.global_step += 1
    
    # store the (pre_s, action, reward, next_state, if_end) tuple in the replay memory
    def perceive(self, pre_s, action, reward, next_state, if_end):
        # as in train(), go through replay_mem.add(); SlidingMemory/PERMemory are
        # assumed to bound their own size at mem_size
        self.replay_mem.add(pre_s, action, reward, next_state, if_end)
        
    
    # choose an action from the current policy: sample from pi(a|s) during training, act greedily at test time
    def action(self, s, test=True):  # the test flag keeps the same interface as the other agents' action()
        s = torch.tensor(s, dtype=torch.float, device=self.device)
        with torch.no_grad():
            action_prob = self.actor_policy_net(s).cpu().numpy()
            num = action_prob.shape[0]

        if not test:
            # during training, sample from the policy distribution for exploration
            return [np.random.choice(self.action_dim, p=action_prob[i]) for i in range(num)]
        else:
            # at test time, act greedily on the policy probabilities
            return np.argmax(action_prob, axis=1)

    def save(self, save_path=None):
        path = save_path if save_path is not None else self.save_path
        torch.save(self.actor_policy_net.state_dict(), path + '_actor_AC.txt')
        torch.save(self.critic_policy_net.state_dict(), path + '_critic_AC.txt')

    def load(self, load_path):
        self.critic_policy_net.load_state_dict(torch.load(load_path + '_critic_AC.txt'))
        self.actor_policy_net.load_state_dict(torch.load(load_path + '_actor_AC.txt'))
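

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original code). It assumes a Gym-style
# discrete-action environment with the classic API (reset -> obs,
# step -> (obs, reward, done, info)) and that `args` carries exactly the fields
# read in AC.__init__; the env name, replay size, and episode count are
# illustrative only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import argparse
    import gym

    env = gym.make('CartPole-v1')
    args = argparse.Namespace(
        replay_size=1000, batch_size=32, gamma=0.99,
        a_lr=1e-3, c_lr=1e-3, tau=0.1,
        state_dim=env.observation_space.shape[0],
        action_dim=env.action_space.n,
    )
    agent = AC(args)

    for episode in range(200):
        state = env.reset()
        done = False
        while not done:
            # action() expects a batch of states, so wrap the state and unwrap the result
            action = agent.action(np.array([state]), test=False)[0]
            next_state, reward, done, _ = env.step(action)
            agent.train(state, action, reward, next_state, done)
            state = next_state

    agent.save('./ac_cartpole')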