Example #1
class DIAYN:
    def __init__(self, s_dim, a_num, skill_num, hidden, lr, gamma, tau,
                 log_prob_reg, alpha, capacity, batch_size, device):
        self.s_dim = s_dim
        self.a_num = a_num
        self.skill_num = skill_num
        self.hidden = hidden
        self.lr = lr
        self.gamma = gamma
        self.tau = tau
        self.log_prob_reg = log_prob_reg
        self.alpha = alpha
        self.capacity = capacity
        self.batch_size = batch_size
        self.device = device
        self.log_pz = torch.log(
            torch.tensor(1 / skill_num, dtype=torch.float, device=device))

        # network initialization
        self.policy = Policy(s_dim, skill_num, hidden, a_num).to(device)
        self.opt_policy = torch.optim.Adam(self.policy.parameters(), lr=lr)

        self.q_net = QNet(s_dim, skill_num, hidden, a_num).to(device)
        self.opt_q_net = torch.optim.Adam(self.q_net.parameters(), lr=lr)

        self.v_net = VNet(s_dim, skill_num, hidden).to(device)
        self.v_net_target = VNet(s_dim, skill_num, hidden).to(device)
        self.v_net_target.load_state_dict(self.v_net.state_dict())
        self.opt_v_net = torch.optim.Adam(self.v_net.parameters(), lr=lr)

        self.discriminator = Discriminator(s_dim, skill_num, hidden).to(device)
        self.opt_discriminator = torch.optim.Adam(
            self.discriminator.parameters(), lr=lr)

        # replay buffer, memory
        self.memory = ReplayBuffer(capacity, batch_size, device)
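The constructor above references network classes (Policy, QNet, VNet, Discriminator) and a ReplayBuffer that are not part of this fragment. As a point of reference, here is a minimal sketch of what the Discriminator might look like, assuming it maps a state to a (log-)probability distribution over skills and accepts the log flag used above; the layer sizes and structure are illustrative, not the original implementation.

import torch.nn as nn
import torch.nn.functional as F

class Discriminator(nn.Module):
    # hypothetical skill discriminator q(z | s): state in, distribution over skills out
    def __init__(self, s_dim, skill_num, hidden):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(s_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, skill_num),
        )

    def forward(self, s, log=False):
        logits = self.net(s)
        # log=True mirrors the discriminator(s_, log=True) call used for the pseudo-reward
        return F.log_softmax(logits, dim=-1) if log else F.softmax(logits, dim=-1)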
Example #2
class SAC:
    def __init__(
            self,
            s_dim,
            a_dim,
            bound,
            device,
            capacity,
            batch_size,
            lr,
            gamma,
            tau,
            log_prob_reg
    ):
        # Parameter Initialization
        self.s_dim = s_dim
        self.a_dim = a_dim
        self.bound = bound
        self.device = device
        self.lr = lr
        self.capacity = capacity
        self.batch_size = batch_size
        self.gamma = gamma
        self.tau = tau
        self.log_prob_reg = log_prob_reg

        hidden = 256
        # Network
        self.q_net = QNet(s_dim, a_dim, hidden).to(device)
        self.target_q_net = QNet(s_dim, a_dim, hidden).to(device)
        self.target_q_net.load_state_dict(self.q_net.state_dict())
        self.opt_q = torch.optim.Adam(self.q_net.parameters(), lr=lr)
        self.policy_net = PolicyNet(s_dim, a_dim, hidden).to(device)
        self.opt_policy = torch.optim.Adam(self.policy_net.parameters(), lr=lr)
        # alpha
        self.alpha = 1
        self.target_entropy = -a_dim
        self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
        self.alpha_optim = torch.optim.Adam([self.log_alpha], lr=lr)

        # replay buffer, memory
        self.memory = ReplayBuffer(capacity, batch_size, device)
Example #3
def main():
    env = CreateBreakout()
    buffer = ReplayBuffer(buffer_capacity)
    behaviourNet = QNet().to(device)
    #behaviourNet.load_state_dict(torch.load(model_path))
    targetNet = QNet().to(device)
    targetNet.load_state_dict(behaviourNet.state_dict())
    optimizer = torch.optim.Adam(behaviourNet.parameters(), learning_rate)
    
    score_history = []
    train_history = []
    #train_history = np.load(history_path+'.npy').tolist()

    step = 0
    score = 0

    state = env.reset()

    print("Train start")
    while step < Train_max_step:
        epsilon = max(0.1, 1.0 - (0.9/final_exploration_step) * step)

        with torch.no_grad():
            action_value = behaviourNet(torch.FloatTensor([state]).to(device))

        # epsilon greedy
        coin = random.random()
        if coin < epsilon:
            action = random.randrange(4)
        else:
            action = action_value.argmax().item()
        
        next_state, reward, done, info = env.step(action)
        buffer.push((state, action, reward, next_state, 1-done))

        score += reward
        step += 1

        if done:
            next_state = env.reset()
            score_history.append(score)
            score = 0
            if len(score_history) > 100:
                del score_history[0]
        
        state = next_state

        if step % update_frequency == 0 and buffer.size() > replay_start_size:
            s_batch, a_batch, r_batch, s_prime_batch, done_batch = buffer.sample(batch_size)
            train(optimizer, behaviourNet, targetNet, s_batch, a_batch, r_batch, s_prime_batch, done_batch)

        if step % update_interval == 0 and buffer.size() > replay_start_size:
            targetNet.load_state_dict(behaviourNet.state_dict())

        if step % save_interval == 0:
            train_history.append(mean(score_history))
            torch.save(behaviourNet.state_dict(), model_path)
            np.save(history_path, np.array(train_history))
            print("step : {}, Average score of last 100 episode : {:.1f}".format(step, mean(score_history)))
    
    torch.save(behaviourNet.state_dict(), model_path)
    np.save(history_path, np.array(train_history))
    print("Train end, avg_score of last 100 episode : {}".format(mean(score_history)))
Example #4
class DIAYN:
    def __init__(self, s_dim, a_num, skill_num, hidden, lr, gamma, tau,
                 log_prob_reg, alpha, capacity, batch_size, device):
        self.s_dim = s_dim
        self.a_num = a_num
        self.skill_num = skill_num
        self.hidden = hidden
        self.lr = lr
        self.gamma = gamma
        self.tau = tau
        self.log_prob_reg = log_prob_reg
        self.alpha = alpha
        self.capacity = capacity
        self.batch_size = batch_size
        self.device = device
        self.log_pz = torch.log(
            torch.tensor(1 / skill_num, dtype=torch.float, device=device))

        # network initialization
        self.policy = Policy(s_dim, skill_num, hidden, a_num).to(device)
        self.opt_policy = torch.optim.Adam(self.policy.parameters(), lr=lr)

        self.q_net = QNet(s_dim, skill_num, hidden, a_num).to(device)
        self.opt_q_net = torch.optim.Adam(self.q_net.parameters(), lr=lr)

        self.v_net = VNet(s_dim, skill_num, hidden).to(device)
        self.v_net_target = VNet(s_dim, skill_num, hidden).to(device)
        self.v_net_target.load_state_dict(self.v_net.state_dict())
        self.opt_v_net = torch.optim.Adam(self.v_net.parameters(), lr=lr)

        self.discriminator = Discriminator(s_dim, skill_num, hidden).to(device)
        self.opt_discriminator = torch.optim.Adam(
            self.discriminator.parameters(), lr=lr)

        # replay buffer, memory
        self.memory = ReplayBuffer(capacity, batch_size, device)

    def get_action(self, s, z):
        s = torch.tensor(s, dtype=torch.float, device=self.device)
        z = torch.tensor(z, dtype=torch.float, device=self.device)
        prob = self.policy(s, z)
        dist = Categorical(prob)
        a = dist.sample()
        return a.item()

    def get_pseudo_reward(self, s, z, a, s_):
        s = torch.tensor(s, dtype=torch.float, device=self.device)
        z = torch.tensor(z, dtype=torch.float, device=self.device)
        a = torch.tensor(a, dtype=torch.long, device=self.device)
        s_ = torch.tensor(s_, dtype=torch.float, device=self.device)

        pseudo_reward = self.discriminator(s_,log=True)[z.argmax(dim=-1)] - \
                        self.log_pz + \
                        self.alpha*self.policy(s,z)[a]

        return pseudo_reward.detach().item()

    def learn(self):
        index = torch.tensor(range(self.batch_size),
                             dtype=torch.long,
                             device=self.device)
        s, z, a, s_, r, done = self.memory.get_sample()
        # soft-actor-critic update
        # update q net
        q = self.q_net(s, z)[index, a].unsqueeze(dim=-1)
        v_ = self.v_net_target(s_, z)
        q_target = r + (1 - done) * self.gamma * v_
        q_loss = F.mse_loss(q, q_target.detach())

        self.opt_q_net.zero_grad()
        q_loss.backward()
        self.opt_q_net.step()

        # update v net
        v = self.v_net(s, z)
        log_prob = self.policy(s, z, log=True)[index, a].unsqueeze(dim=-1)
        q_new = self.q_net(s, z)[index, a].unsqueeze(dim=-1)
        v_target = q_new - log_prob
        v_loss = F.mse_loss(v, v_target.detach())

        self.opt_v_net.zero_grad()
        v_loss.backward()
        self.opt_v_net.step()

        # update policy net
        policy_loss = F.mse_loss(log_prob, q_new.detach())
        self.opt_policy.zero_grad()
        policy_loss.backward()
        self.opt_policy.step()

        # update target net
        self.soft_update(self.v_net_target, self.v_net)

        # update discriminator
        log_q_zs = self.discriminator(s, log=True)
        discriminator_loss = F.nll_loss(log_q_zs, z.argmax(dim=-1))
        self.opt_discriminator.zero_grad()
        discriminator_loss.backward()
        self.opt_discriminator.step()

    def soft_update(self, target, source):
        for target_param, param in zip(target.parameters(),
                                       source.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - self.tau) +
                                    param.data * self.tau)
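For context, a rough sketch of how this interface could be driven from a training loop follows; the environment, hyperparameters, warm-up rule, and memory.store() call are illustrative assumptions (only get_sample() appears in the class above), and the loop uses the classic gym API as in the Breakout example.

import gym
import numpy as np

env = gym.make("CartPole-v1")  # assumed small discrete-action environment
agent = DIAYN(s_dim=4, a_num=2, skill_num=10, hidden=128, lr=3e-4,
              gamma=0.99, tau=0.005, log_prob_reg=1e-6, alpha=0.1,
              capacity=100000, batch_size=128, device="cpu")

total_steps = 0
for episode in range(1000):
    # sample one skill z (one-hot) uniformly and keep it fixed for the whole episode
    z = np.eye(agent.skill_num)[np.random.randint(agent.skill_num)]
    s = env.reset()
    done = False
    while not done:
        a = agent.get_action(s, z)
        s_, _, done, _ = env.step(a)              # the environment reward is discarded
        r = agent.get_pseudo_reward(s, z, a, s_)  # DIAYN trains on its intrinsic reward
        agent.memory.store(s, z, a, s_, r, done)  # assumed push API
        total_steps += 1
        if total_steps > agent.batch_size:        # crude warm-up before starting updates
            agent.learn()
        s = s_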
Example #5
class SAC:
    def __init__(
            self,
            s_dim,
            a_dim,
            bound,
            device,
            capacity,
            batch_size,
            lr,
            gamma,
            tau,
            log_prob_reg
    ):
        # Parameter Initialization
        self.s_dim = s_dim
        self.a_dim = a_dim
        self.bound = bound
        self.device = device
        self.lr = lr
        self.capacity = capacity
        self.batch_size = batch_size
        self.gamma = gamma
        self.tau = tau
        self.log_prob_reg = log_prob_reg

        hidden = 256
        # Network
        self.q_net = QNet(s_dim, a_dim, hidden).to(device)
        self.target_q_net = QNet(s_dim, a_dim, hidden).to(device)
        self.target_q_net.load_state_dict(self.q_net.state_dict())
        self.opt_q = torch.optim.Adam(self.q_net.parameters(), lr=lr)
        self.policy_net = PolicyNet(s_dim, a_dim, hidden).to(device)
        self.opt_policy = torch.optim.Adam(self.policy_net.parameters(), lr=lr)
        # alpha
        self.alpha = 1
        self.target_entropy = -a_dim
        self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
        self.alpha_optim = torch.optim.Adam([self.log_alpha], lr=lr)

        # replay buffer, memory
        self.memory = ReplayBuffer(capacity, batch_size, device)

    def get_action(self, s):
        s = torch.tensor(data=s, dtype=torch.float, device=self.device)
        mean, std = self.policy_net(s)

        normal = Normal(mean, std)
        z = normal.rsample()
        a = torch.tanh(z)

        return self.bound*a.detach().item()

    def get_logprob(self, s, log_reg=1e-6):
        mean, std = self.policy_net(s)

        dist = Normal(mean, std)
        # do not add extra exploration noise to the action here; it degrades the algorithm
        u = dist.rsample()
        a = torch.tanh(u)

        log_prob = dist.log_prob(u) - torch.log(1 - a.pow(2) + log_reg)
        log_prob = log_prob.sum(-1, keepdim=True)
        return a, log_prob

    def learn(self):
        # samples from memory
        s, a, s_, r, done = self.memory.get_sample()

        # update q net
        q = self.q_net(s, a)
        a_, log_prob_ = self.get_logprob(s_)
        q_ = self.target_q_net(s_, a_)
        q_target = r + (1 - done) * self.gamma * (q_ - self.alpha*log_prob_)
        q_loss = F.mse_loss(q, q_target.detach())

        self.opt_q.zero_grad()
        q_loss.backward()
        self.opt_q.step()

        # update policy net
        a_new, log_prob_new = self.get_logprob(s)
        q_new = self.q_net(s, a_new)
        # both loss_functions are available
        # policy_loss = F.mse_loss(log_prob, q_new)
        policy_loss = torch.mean(self.alpha*log_prob_new - q_new)

        self.opt_policy.zero_grad()
        policy_loss.backward()
        self.opt_policy.step()

        # update temperature alpha
        # keeping alpha fixed at a constant value also works
        alpha_loss = -torch.mean(self.log_alpha * (log_prob_new + self.target_entropy).detach())
        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()
        self.alpha = self.log_alpha.exp().item()

        # update target net
        self.soft_update(self.target_q_net, self.q_net)

    def soft_update(self, target, source):
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(
                target_param.data * (1.0 - self.tau) + param.data * self.tau
            )
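get_action and get_logprob assume a PolicyNet that returns a per-dimension mean and a positive standard deviation. A minimal sketch of such a Gaussian policy head is given below; the hidden layout and the log-std clamping range are assumptions, not the original network.

import torch
import torch.nn as nn

class PolicyNet(nn.Module):
    # hypothetical Gaussian policy head: state in, (mean, std) out
    def __init__(self, s_dim, a_dim, hidden):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(s_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        self.mean_head = nn.Linear(hidden, a_dim)
        self.log_std_head = nn.Linear(hidden, a_dim)

    def forward(self, s):
        h = self.body(s)
        mean = self.mean_head(h)
        # clamp log-std for numerical stability; the exact bounds are an assumption
        log_std = torch.clamp(self.log_std_head(h), -20, 2)
        return mean, log_std.exp()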
Example #6
class SAC:
    def __init__(
            self,
            s_dim,
            a_dim,
            bound,
            device,
            capacity,
            batch_size,
            lr,
            gamma,
            tau,
            log_prob_reg
    ):
        # Parameter Initialization
        self.s_dim = s_dim
        self.a_dim = a_dim
        self.bound = bound
        self.device = device
        self.lr = lr
        self.capacity = capacity
        self.batch_size = batch_size
        self.gamma = gamma
        self.tau = tau
        self.log_prob_reg = log_prob_reg

        hidden = 256
        # Network
        self.v_net = VNet(s_dim, hidden).to(device)
        self.target_v_net = VNet(s_dim, hidden).to(device)
        self.target_v_net.load_state_dict(self.v_net.state_dict())
        self.opt_v = torch.optim.Adam(self.v_net.parameters(), lr=lr)

        self.q_net = QNet(s_dim, a_dim, hidden).to(device)
        self.opt_q = torch.optim.Adam(self.q_net.parameters(), lr=lr)
        self.policy_net = PolicyNet(s_dim, a_dim, hidden).to(device)
        self.opt_policy = torch.optim.Adam(self.policy_net.parameters(), lr=lr)

        # replay buffer, memory
        self.memory = ReplayBuffer(capacity, batch_size, device)

    def get_action(self, s):
        s = torch.tensor(data=s, dtype=torch.float, device=self.device)
        mean, std = self.policy_net(s)

        normal = Normal(mean, std)
        z = normal.rsample()
        a = torch.tanh(z)

        return self.bound*a.detach().item()

    def get_logprob(self, s, log_reg=1e-6):
        mean, std = self.policy_net(s)

        dist = Normal(mean, std)
        u = dist.rsample()
        a = torch.tanh(u)

        log_prob = dist.log_prob(u) - torch.log(1 - a.pow(2) + log_reg)
        log_prob = log_prob.sum(-1, keepdim=True)

        return a, log_prob

    def learn(self):
        # samples from memory
        s, a, s_, r, done = self.memory.get_sample()

        # update q net
        q = self.q_net(s, a)
        v_ = self.target_v_net(s_)
        q_target = r + (1 - done) * self.gamma * v_
        q_loss = F.mse_loss(q, q_target.detach())

        self.opt_q.zero_grad()
        q_loss.backward()
        self.opt_q.step()

        # update v net
        v = self.v_net(s)
        new_a, log_prob = self.get_logprob(s)
        q_new = self.q_net(s, new_a)
        v_target = q_new - log_prob
        value_loss = F.mse_loss(v, v_target.detach())

        self.opt_v.zero_grad()
        value_loss.backward()
        self.opt_v.step()

        # update policy net
        # both loss_functions are available
        # policy_loss = F.mse_loss(log_prob, q_new)
        policy_loss = torch.mean(log_prob - q_new)

        self.opt_policy.zero_grad()
        policy_loss.backward()
        self.opt_policy.step()

        # update target net
        self.soft_update(self.target_v_net, self.v_net)

    def soft_update(self, target, source):
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(
                target_param.data * (1.0 - self.tau) + param.data * self.tau
            )
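Every agent above constructs ReplayBuffer(capacity, batch_size, device) and calls get_sample(), but the buffer itself is never shown. A minimal sketch of one possible implementation for the SAC examples follows; the store() name, the storage layout, and the column-vector reshaping of r and done are assumptions (the DIAYN variant would additionally store the skill z).

import random
from collections import deque
import torch

class ReplayBuffer:
    # hypothetical uniform replay buffer matching the get_sample() calls above
    def __init__(self, capacity, batch_size, device):
        self.buffer = deque(maxlen=capacity)
        self.batch_size = batch_size
        self.device = device

    def store(self, s, a, s_, r, done):
        # assumed push API; only get_sample() appears in the agents above
        self.buffer.append((s, a, s_, r, done))

    def get_sample(self):
        batch = random.sample(self.buffer, self.batch_size)
        s, a, s_, r, done = map(list, zip(*batch))
        s = torch.tensor(s, dtype=torch.float, device=self.device)
        a = torch.tensor(a, dtype=torch.float, device=self.device)
        s_ = torch.tensor(s_, dtype=torch.float, device=self.device)
        # reshape r and done to column vectors so they broadcast against Q values
        r = torch.tensor(r, dtype=torch.float, device=self.device).unsqueeze(-1)
        done = torch.tensor(done, dtype=torch.float, device=self.device).unsqueeze(-1)
        return s, a, s_, r, done

    def __len__(self):
        return len(self.buffer)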