Example #1
import argparse
import os
import pickle as pkl
import sys
import time

import gym
import numpy as np
import torch

# PolicyNet and sample_batch are assumed to be provided by the project's own modules.


def main():
    #Parse arguments
    #----------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", default="CartPole-v0")
    parser.add_argument("--conti", action="store_true")
    args = parser.parse_args()

    #Parameters
    #----------------------------
    env_id = args.env
    mb_size = 256
    lr = 1e-5
    n_iter = 100000
    disp_step = 1000
    save_step = 10000
    save_dir = "./save"
    device = "cuda:0"
    expert_path = "../save/{}_traj.pkl".format(args.env)

    #Create environment
    #----------------------------
    env = gym.make(env_id)

    s_dim = env.observation_space.shape[0]

    if args.conti:
        a_dim = env.action_space.shape[0]
    else:
        a_dim = env.action_space.n

    #Load expert trajectories
    #----------------------------
    if os.path.exists(expert_path):
        with open(expert_path, "rb") as f:
            s_traj, a_traj = pkl.load(f)

        s_traj = np.concatenate(s_traj, 0)
        a_traj = np.concatenate(a_traj, 0)
    else:
        print("ERROR: No expert trajectory file found at {}".format(expert_path))
        sys.exit(1)

    #Create model
    #----------------------------
    policy_net = PolicyNet(s_dim, a_dim, conti=args.conti).to(device)
    opt = torch.optim.Adam(policy_net.parameters(), lr)

    #Load model
    #----------------------------
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)

    if os.path.exists(os.path.join(save_dir, "{}.pt".format(args.env))):
        print("Loading the model ... ", end="")
        checkpoint = torch.load(
            os.path.join(save_dir, "{}.pt".format(args.env)),
            map_location=device)
        policy_net.load_state_dict(checkpoint["PolicyNet"])
        start_it = checkpoint["it"]
        print("Done.")
    else:
        start_it = 0

    #Start training
    #----------------------------
    t_start = time.time()
    policy_net.train()

    for it in range(start_it, n_iter + 1):
        #Train
        mb_obs, mb_actions = sample_batch(s_traj, a_traj, mb_size)
        mb_a_logps, mb_ents = policy_net.evaluate(
            torch.from_numpy(mb_obs).to(device),
            torch.from_numpy(mb_actions).to(device))
        loss = -mb_a_logps.mean()

        opt.zero_grad()
        loss.backward()
        opt.step()

        #Print the result
        if it % disp_step == 0:
            print("[{:5d} / {:5d}] Elapsed time = {:.2f}, actor loss = {:.6f}".
                  format(it, n_iter,
                         time.time() - t_start, loss.item()))

        #Save model
        if it % save_step == 0:
            print("Saving the model ... ", end="")
            torch.save({
                "it": it,
                "PolicyNet": policy_net.state_dict()
            }, os.path.join(save_dir, "{}.pt".format(args.env)))
            print("Done.")
            print()

    env.close()


if __name__ == "__main__":
    main()
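
The listing calls a sample_batch helper whose definition is not shown here. The following is only a minimal sketch of what it might look like, assuming s_traj and a_traj are the concatenated NumPy arrays built above:

import numpy as np

def sample_batch(s_traj, a_traj, mb_size):
    # Draw `mb_size` (state, action) pairs uniformly at random from the expert data.
    rand_idx = np.random.randint(0, len(s_traj), size=mb_size)
    return s_traj[rand_idx], a_traj[rand_idx]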
Example #2
import torch
import torch.nn.functional as F
import torch.optim as optim

# ValueNet, SoftQNet, PolicyNet, ReplayBeffer, and device are assumed to be
# defined in the project's own modules.


class SAC:
    def __init__(self, env, gamma, tau, buffer_maxlen, value_lr, q_lr, policy_lr):

        self.env = env
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]
        self.action_range = [env.action_space.low, env.action_space.high]

        # hyperparameters
        self.gamma = gamma
        self.tau = tau

        # initialize networks
        self.value_net = ValueNet(self.state_dim).to(device)
        self.target_value_net = ValueNet(self.state_dim).to(device)
        self.q1_net = SoftQNet(self.state_dim, self.action_dim).to(device)
        self.q2_net = SoftQNet(self.state_dim, self.action_dim).to(device)
        self.policy_net = PolicyNet(self.state_dim, self.action_dim).to(device)

        # Initialize the target value network with the value network's parameters
        for target_param, param in zip(self.target_value_net.parameters(), self.value_net.parameters()):
            target_param.data.copy_(param.data)

        # Initialize the optimizers
        self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=value_lr)
        self.q1_optimizer = optim.Adam(self.q1_net.parameters(), lr=q_lr)
        self.q2_optimizer = optim.Adam(self.q2_net.parameters(), lr=q_lr)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=policy_lr)

        # Initialize the replay buffer
        self.buffer = ReplayBeffer(buffer_maxlen)

    def get_action(self, state):
        action = self.policy_net.action(state)
        # Rescale the policy output from [-1, 1] to the environment's action range
        action = action * (self.action_range[1] - self.action_range[0]) / 2.0 + \
                 (self.action_range[1] + self.action_range[0]) / 2.0

        return action

    def update(self, batch_size):
        state, action, reward, next_state, done = self.buffer.sample(batch_size)
        new_action, log_prob = self.policy_net.evaluate(state)

        # V value loss: regress V(s) toward min(Q1, Q2) - log_prob
        value = self.value_net(state)
        new_q1_value = self.q1_net(state, new_action)
        new_q2_value = self.q2_net(state, new_action)
        target_v_value = torch.min(new_q1_value, new_q2_value) - log_prob
        value_loss = F.mse_loss(value, target_v_value.detach())

        # Soft Q loss (here `done` is assumed to be a continuation mask:
        # 0 at terminal transitions, 1 otherwise)
        q1_value = self.q1_net(state, action)
        q2_value = self.q2_net(state, action)
        target_value = self.target_value_net(next_state)
        target_q_value = reward + done * self.gamma * target_value
        q1_value_loss = F.mse_loss(q1_value, target_q_value.detach())
        q2_value_loss = F.mse_loss(q2_value, target_q_value.detach())

        # Policy loss
        policy_loss = (log_prob - torch.min(new_q1_value, new_q2_value)).mean()

        # Update v
        self.value_optimizer.zero_grad()
        value_loss.backward()
        self.value_optimizer.step()

        # Update Soft q
        self.q1_optimizer.zero_grad()
        self.q2_optimizer.zero_grad()
        q1_value_loss.backward()
        q2_value_loss.backward()
        self.q1_optimizer.step()
        self.q2_optimizer.step()

        # Update Policy
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # Soft (Polyak) update of the target value network
        for target_param, param in zip(self.target_value_net.parameters(), self.value_net.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
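
Example #2 only defines the agent; a driver loop is still needed to collect transitions and call update(). The sketch below shows one possible loop. The environment name, hyperparameter values, the buffer's push()/len() interface, and the old Gym reset()/step() API are illustrative assumptions rather than part of the source.

import gym

env = gym.make("Pendulum-v0")                      # illustrative environment choice
agent = SAC(env, gamma=0.99, tau=0.01, buffer_maxlen=50000,
            value_lr=3e-4, q_lr=3e-4, policy_lr=3e-4)

batch_size = 128
for episode in range(200):
    state = env.reset()                            # old Gym API: returns only the observation
    episode_reward = 0.0
    done = False

    while not done:
        action = agent.get_action(state)
        next_state, reward, done, _ = env.step(action)

        # Assumed buffer interface; update() treats the stored flag as a
        # continuation mask (0 at terminal transitions), so convert `done` here.
        agent.buffer.push((state, action, reward, next_state, 1.0 - float(done)))

        state = next_state
        episode_reward += reward

        # Start learning once enough transitions have been collected.
        if len(agent.buffer) > batch_size:
            agent.update(batch_size)

    print("Episode {:3d} | return = {:.2f}".format(episode, episode_reward))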