Example #1
def sac_opp(
    global_ac,
    global_ac_targ,
    global_cpc,
    rank,
    T,
    E,
    args,
    scores,
    wins,
    buffer,
    device=None,
    tensorboard_dir=None,
):
    torch.manual_seed(args.seed + rank)
    np.random.seed(args.seed + rank)
    # writer = GlobalSummaryWriter.getSummaryWriter()
    tensorboard_dir = os.path.join(tensorboard_dir, str(rank))
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)
    writer = SummaryWriter(log_dir=tensorboard_dir)
    if args.exp_name == "test":
        env = gym.make("CartPole-v0")
    elif args.non_station:
        env = make_ftg_ram_nonstation(args.env,
                                      p2_list=args.list,
                                      total_episode=args.station_rounds,
                                      stable=args.stable)
    else:
        env = make_ftg_ram(args.env, p2=args.p2)
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n
    print("set up child process env")
    local_ac = MLPActorCritic(obs_dim + args.c_dim, act_dim,
                              hidden_sizes=[args.hid] * args.l).to(device)
    local_ac.load_state_dict(global_ac.state_dict())
    print("local ac load global ac")

    # Freeze target networks with respect to optimizers (only update via polyak averaging)
    # Async Version
    for p in global_ac_targ.parameters():
        p.requires_grad = False

    # Experience buffer
    if args.cpc:
        replay_buffer = ReplayBufferOppo(obs_dim=obs_dim,
                                         max_size=args.replay_size,
                                         encoder=global_cpc)
    else:
        replay_buffer = ReplayBuffer(obs_dim=obs_dim, size=args.replay_size)

    # Entropy Tuning
    target_entropy = -torch.prod(
        torch.Tensor(env.action_space.shape).to(
            device)).item()  # heuristic value from the paper
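    # Note: this is the continuous-control heuristic from the SAC paper; for a
    # discrete action space env.action_space.shape is empty, so the product is
    # 1 and target_entropy evaluates to -1.0.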
    alpha = max(local_ac.log_alpha.exp().item(),
                args.min_alpha) if not args.fix_alpha else args.min_alpha

    # Set up optimizers for policy and q-function
    # Async Version
    pi_optimizer = Adam(global_ac.pi.parameters(), lr=args.lr, eps=1e-4)
    q1_optimizer = Adam(global_ac.q1.parameters(), lr=args.lr, eps=1e-4)
    q2_optimizer = Adam(global_ac.q2.parameters(), lr=args.lr, eps=1e-4)
    cpc_optimizer = Adam(global_cpc.parameters(), lr=args.lr, eps=1e-4)
    alpha_optim = Adam([global_ac.log_alpha], lr=args.lr, eps=1e-4)
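    # The optimizers own the *global* (shared) parameters, while the losses are
    # computed with the local copy; before each optimizer.step() the local
    # gradients are copied onto the global parameters (A3C-style async update).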

    # Prepare for interaction with environment
    o, ep_ret, ep_len = env.reset(), 0, 0
    if args.cpc:
        c_hidden = global_cpc.init_hidden(1, args.c_dim, use_gpu=args.cuda)
        c1, c_hidden = global_cpc.predict(o, c_hidden)
        assert len(c1.shape) == 3
        c1 = c1.flatten().cpu().numpy()
        all_embeddings = []
        meta = []
    trajectory = list()
    p2 = env.p2
    p2_list = [str(p2)]
    discard = False
    uncertainties = []
    local_t, local_e = 0, 0
    t = T.value()
    e = E.value()
    glod_input = list()
    glod_target = list()
    # Main loop: collect experience in env and update/log each epoch
    while e <= args.episode:
        with torch.no_grad():
            # Until start_steps have elapsed, randomly sample actions
            # from a uniform distribution for better exploration. Afterwards,
            # use the learned policy.
            if t > args.start_steps:
                if args.cpc:
                    a = local_ac.get_action(np.concatenate((o, c1), axis=0),
                                            device=device)
                    a_prob = local_ac.act(
                        torch.as_tensor(np.expand_dims(np.concatenate((o, c1),
                                                                      axis=0),
                                                       axis=0),
                                        dtype=torch.float32,
                                        device=device))
                else:
                    a = local_ac.get_action(o, greedy=True, device=device)
                    a_prob = local_ac.act(
                        torch.as_tensor(np.expand_dims(o, axis=0),
                                        dtype=torch.float32,
                                        device=device))
            else:
                a = env.action_space.sample()
                o_in = np.concatenate((o, c1), axis=0) if args.cpc else o
                a_prob = local_ac.act(
                    torch.as_tensor(np.expand_dims(o_in, axis=0),
                                    dtype=torch.float32,
                                    device=device))
        uncertainty = ood_scores(a_prob).item()

        # Step the env
        o2, r, d, info = env.step(a)
        if info.get('no_data_receive', False):
            discard = True
        ep_ret += r
        ep_len += 1

        # Ignore the "done" signal if it comes from hitting the time
        # horizon (that is, when it's an artificial terminal signal
        # that isn't based on the agent's state)
        d = False if (ep_len == args.max_ep_len) or discard else d
        glod_input.append(o)
        glod_target.append(a)

        if args.cpc:
            # changed the trace structure for further analysis
            c2, c_hidden = global_cpc.predict(o2, c_hidden)
            assert len(c2.shape) == 3
            c2 = c2.flatten().cpu().numpy()
            replay_buffer.store(np.concatenate((o, c1), axis=0), a, r,
                                np.concatenate((o2, c2), axis=0), d)
            trajectory.append([o, a, r, o2, d, c1, c2, ep_len])
            all_embeddings.append(c1)
            meta.append([env.p2, local_e, ep_len, r, a, uncertainty])
            c1 = c2
        else:
            replay_buffer.store(o, a, r, o2, d)

        # Super critical, easy to overlook step: make sure to update
        # most recent observation!
        o = o2
        T.increment()
        t = T.value()
        local_t += 1

        # End of trajectory handling
        if d or (ep_len == args.max_ep_len) or discard:
            if args.cpc:
                replay_buffer.store(trajectory)
            E.increment()
            e = E.value()
            local_e += 1
            # logger.store(EpRet=ep_ret, EpLen=ep_len)
            if info.get('win', False):
                wins.append(1)
            else:
                wins.append(0)
            scores.append(ep_ret)
            m_score = np.mean(scores[-100:])
            win_rate = np.mean(wins[-100:])
            print(
                "Process {}, opponent:{}, # of global_episode :{},  # of global_steps :{}, round score: {}, mean score : {:.1f}, win_rate:{}, steps: {}, alpha: {}"
                .format(rank, args.p2, e, t, ep_ret, m_score, win_rate, ep_len,
                        alpha))
            writer.add_scalar("metrics/round_score", ep_ret, e)
            writer.add_scalar("metrics/mean_score", m_score.item(), e)
            writer.add_scalar("metrics/win_rate", win_rate.item(), e)
            writer.add_scalar("metrics/round_step", ep_len, e)
            writer.add_scalar("metrics/alpha", alpha, e)

            # CPC update handling
            if local_e > args.batch_size and local_e % args.update_every == 0 and args.cpc:
                data, indexes, min_len = replay_buffer.sample_traj(
                    args.batch_size)
                global_cpc.train()
                cpc_optimizer.zero_grad()
                c_hidden = global_cpc.init_hidden(len(data),
                                                  args.c_dim,
                                                  use_gpu=args.cuda)
                acc, loss, latents = global_cpc(data, c_hidden)

                replay_buffer.update_latent(indexes, min_len, latents.detach())
                loss.backward()
                # add gradient clipping
                nn.utils.clip_grad_norm_(global_cpc.parameters(), 20)
                cpc_optimizer.step()

                writer.add_scalar("training/acc", acc, e)
                writer.add_scalar("training/cpc_loss", loss.detach().item(), e)

                writer.add_embedding(mat=np.array(all_embeddings),
                                     metadata=meta,
                                     metadata_header=[
                                         "opponent", "round", "step", "reward",
                                         "action", "uncertainty"
                                     ])
                c_hidden = global_cpc.init_hidden(1,
                                                  args.c_dim,
                                                  use_gpu=args.cuda)
            o, ep_ret, ep_len = env.reset(), 0, 0
            trajectory = list()
            discard = False

        # OOD update stage
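        # GLOD-based filtering (as used below): convert_to_glod appears to fit a
        # Gaussian-likelihood head on the recently collected (obs, action) pairs,
        # retrieve_scores assigns each stored observation an OOD score, and only
        # transitions whose scores fall inside the
        # [ood_drop_lower, ood_drop_upper] percentile band are kept.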
        if ((t >= args.ood_update_step and local_t % args.ood_update_step == 0)
                or replay_buffer.is_full()) and args.ood:
            # Use all data collected over the last args.ood_update_step steps as the training set
            print("Conduct OOD updating")
            ood_train = (glod_input, glod_target)
            glod_model = convert_to_glod(global_ac.pi,
                                         train_loader=ood_train,
                                         hidden_dim=args.hid,
                                         act_dim=act_dim,
                                         device=device)
            glod_scores = retrieve_scores(
                glod_model,
                replay_buffer.obs_buf[:replay_buffer.size],
                device=device,
                k=args.ood_K)
            glod_scores = glod_scores.detach().cpu().numpy()
            print(len(glod_scores))
            writer.add_histogram(values=glod_scores,
                                 max_bins=300,
                                 global_step=local_t,
                                 tag="OOD")
            drop_points = np.percentile(
                a=glod_scores, q=[args.ood_drop_lower, args.ood_drop_upper])
            lower, upper = drop_points[0], drop_points[1]
            print(lower, upper)
            mask = np.logical_and((glod_scores >= lower),
                                  (glod_scores <= upper))
            reserved_indexes = np.argwhere(mask).flatten()
            print(len(reserved_indexes))
            if len(reserved_indexes) > 0:
                replay_buffer.ood_drop(reserved_indexes)
                glod_input = list()
                glod_target = list()

        # SAC Update handling
        if local_t >= args.update_after and local_t % args.update_every == 0:
            for j in range(args.update_every):

                batch = replay_buffer.sample_trans(batch_size=args.batch_size,
                                                   device=device)
                # First run one gradient descent step for Q1 and Q2
                q1_optimizer.zero_grad()
                q2_optimizer.zero_grad()
                loss_q = local_ac.compute_loss_q(batch, global_ac_targ,
                                                 args.gamma, alpha)
                loss_q.backward()

                # Next run one gradient descent step for pi.
                pi_optimizer.zero_grad()
                loss_pi, entropy = local_ac.compute_loss_pi(batch, alpha)
                loss_pi.backward()

                alpha_optim.zero_grad()
                alpha_loss = -(local_ac.log_alpha *
                               (entropy + target_entropy).detach()).mean()
                alpha_loss.backward(retain_graph=False)
                alpha = max(
                    local_ac.log_alpha.exp().item(),
                    args.min_alpha) if not args.fix_alpha else args.min_alpha

                nn.utils.clip_grad_norm_(local_ac.parameters(), 20)
                for global_param, local_param in zip(global_ac.parameters(),
                                                     local_ac.parameters()):
                    global_param._grad = local_param.grad

                pi_optimizer.step()
                q1_optimizer.step()
                q2_optimizer.step()
                alpha_optim.step()

                state_dict = global_ac.state_dict()
                local_ac.load_state_dict(state_dict)

                # Finally, update target networks by polyak averaging.
                with torch.no_grad():
                    for p, p_targ in zip(global_ac.parameters(),
                                         global_ac_targ.parameters()):
                        p_targ.data.copy_((1 - args.polyak) * p.data +
                                          args.polyak * p_targ.data)

                writer.add_scalar("training/pi_loss",
                                  loss_pi.detach().item(), t)
                writer.add_scalar("training/q_loss", loss_q.detach().item(), t)
                writer.add_scalar("training/alpha_loss",
                                  alpha_loss.detach().item(), t)
                writer.add_scalar("training/entropy",
                                  entropy.detach().mean().item(), t)

        if t % args.save_freq == 0 and t > 0:
            torch.save(
                global_ac.state_dict(),
                os.path.join(args.save_dir, args.exp_name, args.model_para))
            torch.save(
                global_cpc.state_dict(),
                os.path.join(args.save_dir, args.exp_name, args.cpc_para))
            state_dict_trans(
                global_ac.state_dict(),
                os.path.join(args.save_dir, args.exp_name, args.numpy_para))
            torch.save((e, t, list(scores), list(wins)),
                       os.path.join(args.save_dir, args.exp_name,
                                    args.train_indicator))
            print("Saving model at episode:{}".format(t))
Example #2
def train(global_model, rank, T, scores):
    env = make_ftg_ram(env_name, p2=p2)
    state_shape = env.observation_space.shape[0]
    action_shape = env.action_space.n
    local_model = ActorCritic(state_shape, action_shape, hidden_size)
    local_model.load_state_dict(global_model.state_dict())
    # MODEL_STATE = "/home/byron/Repos/FTG4.50/OpenAI/ByronAI.numpy"
    # test_model = ActorCriticNumpy(MODEL_STATE)
    optimizer = optim.Adam(global_model.parameters(), lr=learning_rate)

    while True:
        discard = False
        done = False
        s = env.reset()
        score = 0
        sum_entropy = 0
        step = 0
        while not done:
            s_lst, a_lst, r_lst = [], [], []
            for t in range(update_interval):
                prob = local_model.pi(torch.from_numpy(s).float())
                # test_prob = test_model.pi(s)
                # diff = prob.detach().numpy() - test_prob
                # print(diff.sum())
                m = Categorical(prob)
                a = m.sample().item()
                s_prime, r, done, info = env.step(a)
                if info.get('no_data_receive', False):
                    discard = True
                    break

                s_lst.append(s)
                a_lst.append([a])
                # r_lst.append(r/100.0)
                r_lst.append(r)

                s = s_prime

                score += r
                sum_entropy += Categorical(probs=prob.detach()).entropy()
                step += 1
                if done:
                    break
            if discard:
                break
            s_final = torch.tensor(s_prime, dtype=torch.float)
            R = 0.0 if done else local_model.v(s_final).item()
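            # Build the n-step TD targets backwards from the bootstrap value R
            # (V(s') if the rollout was truncated, 0 if the episode terminated):
            # each target is r_t + gamma * (target of the following step).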
            td_target_lst = []
            for reward in r_lst[::-1]:
                R = gamma * R + reward
                td_target_lst.append([R])
            td_target_lst.reverse()

            s_batch, a_batch, td_target = torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst), \
                torch.tensor(td_target_lst)
            advantage = td_target - local_model.v(s_batch)

            pi = local_model.pi(s_batch, softmax_dim=1)
            pi_a = pi.gather(1, a_batch)
            loss = -torch.log(pi_a) * advantage.detach() + \
                F.smooth_l1_loss(local_model.v(s_batch), td_target.detach()) - \
                (entropy_weight * -(torch.log(pi) * pi).sum())

            optimizer.zero_grad()
            loss.mean().backward()
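            # Hand the locally computed gradients to the shared global model so
            # the single shared Adam optimizer can apply them, then refresh the
            # local copy from the updated global weights.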
            for global_param, local_param in zip(global_model.parameters(),
                                                 local_model.parameters()):
                global_param._grad = local_param.grad
            optimizer.step()
            local_model.load_state_dict(global_model.state_dict())
        if discard:
            continue
        T.increment()
        t = T.value()
        scores.append(score)
        m_score = np.mean(scores[-100:])
        print(
            "Process {}, # of episode :{}, round score: {}, mean score : {:.1f}, entropy: {}, steps: {}"
            .format(rank, t, score, m_score, sum_entropy / step, step))
        if t % save_interval == 0 and t > 0:
            torch.save(global_model.state_dict(),
                       os.path.join(save_dir, "model"))
            state_dict_trans(global_model.state_dict(),
                             os.path.join(save_dir, numpy_para))
            print("Saving model at episode:{}".format(t))
Example #3
def sac(
    global_ac,
    global_ac_targ,
    rank,
    T,
    E,
    args,
    scores,
    wins,
    buffer_q,
    device=None,
    tensorboard_dir=None,
):
    torch.manual_seed(args.seed + rank)
    np.random.seed(args.seed + rank)
    # writer = GlobalSummaryWriter.getSummaryWriter()
    tensorboard_dir = os.path.join(tensorboard_dir, str(rank))
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)
    writer = SummaryWriter(log_dir=tensorboard_dir)
    # if args.exp_name == "test":
    #     env = gym.make("CartPole-v0")
    # elif args.non_station:
    #     env = make_ftg_ram_nonstation(args.env, p2_list=args.list, total_episode=args.station_rounds,stable=args.stable)
    # else:
    #     env = make_ftg_ram(args.env, p2=args.p2)
    # obs_dim = env.observation_space.shape[0]
    # act_dim = env.action_space.n
    env = Soccer()
    # env = gym.make("CartPole-v0")
    obs_dim = env.n_features
    act_dim = env.n_actions
    print("set up child process env")
    local_ac = MLPActorCritic(obs_dim, act_dim,
                              hidden_sizes=[args.hid] * args.l).to(device)
    state_dict = global_ac.state_dict()
    local_ac.load_state_dict(state_dict)
    print("local ac load global ac")

    # Freeze target networks with respect to optimizers (only update via polyak averaging)
    # Async Version
    for p in global_ac_targ.parameters():
        p.requires_grad = False

    # Experience buffer
    replay_buffer = ReplayBuffer(obs_dim=obs_dim, size=args.replay_size)
    training_buffer = ReplayBuffer(obs_dim=obs_dim, size=args.replay_size)

    # Entropy Tuning
    target_entropy = -np.log((1.0 / act_dim)) * 0.5
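    # i.e. target_entropy = 0.5 * log(act_dim), half the entropy of a uniform
    # policy over the discrete action set.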
    alpha = max(local_ac.log_alpha.exp().item(),
                args.min_alpha) if not args.fix_alpha else args.min_alpha

    # Set up optimizers for policy and q-function
    # Async Version
    pi_optimizer = Adam(global_ac.pi.parameters(), lr=args.lr, eps=1e-4)
    q1_optimizer = Adam(global_ac.q1.parameters(), lr=args.lr, eps=1e-4)
    q2_optimizer = Adam(global_ac.q2.parameters(), lr=args.lr, eps=1e-4)
    alpha_optim = Adam([global_ac.log_alpha], lr=args.lr, eps=1e-4)

    # Prepare for interaction with environment
    o, ep_ret, ep_len = env.reset(), 0, 0
    discard = False
    glod_model = None
    glod_lower = None
    glod_upper = None
    last_updated = 0
    saved_e = 0
    t = T.value()
    e = E.value()
    local_t, local_e = 0, 0
    # Main loop: collect experience in env and update/log each epoch
    while e <= args.episode:

        # Until start_steps have elapsed, randomly sample actions
        # from a uniform distribution for better exploration. Afterwards,
        # use the learned policy.
        with torch.no_grad():
            a = local_ac.get_action(o, device=device)
        if hasattr(env, 'p2'):
            p2 = env.p2
        else:
            p2 = None
        # Step the env
        o2, r, d, info = env.step(a, np.random.randint(act_dim))
        env.render()

        if info.get('no_data_receive', False):
            discard = True
        ep_ret += r
        ep_len += 1

        # Ignore the "done" signal if it comes from hitting the time
        # horizon (that is, when it's an artificial terminal signal
        # that isn't based on the agent's state)
        # d = False if (ep_len == args.max_ep_len) or discard else d

        # Store experience to replay buffer
        if glod_model is None or not args.ood:
            replay_buffer.store(o, a, r, o2, d, str(p2))
            training_buffer.store(o, a, r, o2, d, str(p2))
        else:
            obs_glod_score = retrieve_scores(glod_model,
                                             np.expand_dims(o, axis=0),
                                             device=torch.device("cpu"),
                                             k=args.ood_K)
            if glod_lower <= obs_glod_score <= glod_upper:
                training_buffer.store(o, a, r, o2, d, str(p2))
            replay_buffer.store(o, a, r, o2, d, str(p2))
        # Super critical, easy to overlook step: make sure to update
        # most recent observation!
        o = o2

        T.increment()
        t = T.value()
        local_t += 1

        # End of trajectory handling
        if d or (ep_len == args.max_ep_len) or discard:
            E.increment()
            e = E.value()
            local_e += 1
            # logger.store(EpRet=ep_ret, EpLen=ep_len)
            if info.get('win', False):
                wins.append(1)
            else:
                wins.append(0)
            scores.append(ep_ret)
            m_score = np.mean(scores[-100:])
            win_rate = np.mean(wins[-100:])
            print(
                "Process {}, opponent:{}, # of global_episode :{},  # of global_steps :{}, round score: {}, mean score : {:.1f}, win_rate:{}, steps: {}, alpha: {}"
                .format(rank, args.p2, e, t, ep_ret, m_score, win_rate, ep_len,
                        alpha))
            writer.add_scalar("metrics/round_score", ep_ret, e)
            writer.add_scalar("metrics/mean_score", m_score.item(), e)
            writer.add_scalar("metrics/win_rate", win_rate.item(), e)
            writer.add_scalar("metrics/round_step", ep_len, e)
            writer.add_scalar("metrics/alpha", alpha, e)
            o, ep_ret, ep_len = env.reset(), 0, 0
            discard = False

        # OOD update stage; runs on CPU only, since GPU memory cannot hold this much data
        if local_e >= args.ood_starts and local_e % args.ood_update_rounds == 0 and args.ood and local_e != last_updated:
            print("OOD updating at rounds {}".format(e))
            print("Replay Buffer Size: {}, Training Buffer Size: {}".format(
                replay_buffer.size, training_buffer.size))
            glod_idxs = np.random.randint(0,
                                          training_buffer.size,
                                          size=int(training_buffer.size *
                                                   args.ood_train_per))
            glod_input = training_buffer.obs_buf[glod_idxs]
            glod_target = training_buffer.act_buf[glod_idxs]
            ood_train = (glod_input, glod_target)
            glod_model = deepcopy(global_ac.pi).cpu()
            glod_model = convert_to_glod(glod_model,
                                         train_loader=ood_train,
                                         hidden_dim=args.hid,
                                         act_dim=act_dim,
                                         device=torch.device("cpu"))
            training_buffer = deepcopy(replay_buffer)
            glod_scores = retrieve_scores(
                glod_model,
                replay_buffer.obs_buf[:training_buffer.size],
                device=torch.device("cpu"),
                k=args.ood_K)
            glod_scores = glod_scores.detach().cpu().numpy()
            glod_p2 = training_buffer.p2_buf[:training_buffer.size]
            drop_points = np.percentile(
                a=glod_scores, q=[args.ood_drop_lower, args.ood_drop_upper])
            glod_lower, glod_upper = drop_points[0], drop_points[1]
            mask = np.logical_and((glod_scores >= glod_lower),
                                  (glod_scores <= glod_upper))
            reserved_indexes = np.argwhere(mask).flatten()
            if len(reserved_indexes) > 0:
                training_buffer.ood_drop(reserved_indexes)
            writer.add_histogram(values=glod_scores,
                                 max_bins=300,
                                 global_step=local_e,
                                 tag="OOD")
            print("Replay Buffer Size: {}, Training Buffer Size: {}".format(
                replay_buffer.size, training_buffer.size))
            torch.save(
                (glod_scores, replay_buffer.p2_buf[:replay_buffer.size]),
                os.path.join(args.save_dir, args.exp_name,
                             "glod_info_{}_{}".format(rank, local_e)))
            last_updated = local_e

        # SAC Update handling
        if local_e >= args.update_after and local_t % args.update_every == 0:
            for j in range(args.update_every):
                batch = training_buffer.sample_trans(args.batch_size,
                                                     device=device)
                # First compute gradients for Q1 and Q2
                q1_optimizer.zero_grad()
                q2_optimizer.zero_grad()
                loss_q = local_ac.compute_loss_q(batch, global_ac_targ,
                                                 args.gamma, alpha)
                loss_q.backward()

                # Next compute the gradient for pi.
                pi_optimizer.zero_grad()
                loss_pi, entropy = local_ac.compute_loss_pi(batch, alpha)
                loss_pi.backward()

                # Entropy temperature (alpha) loss and gradient
                alpha_optim.zero_grad()
                alpha_loss = -(local_ac.log_alpha *
                               (entropy + target_entropy).detach()).mean()
                alpha_loss.backward(retain_graph=False)
                alpha = max(
                    local_ac.log_alpha.exp().item(),
                    args.min_alpha) if not args.fix_alpha else args.min_alpha

                # Clip the local gradients, copy them onto the shared global
                # parameters, and only then step the shared optimizers.
                nn.utils.clip_grad_norm_(local_ac.parameters(),
                                         max_norm=20,
                                         norm_type=2)
                for global_param, local_param in zip(global_ac.parameters(),
                                                     local_ac.parameters()):
                    global_param._grad = local_param.grad

                q1_optimizer.step()
                q2_optimizer.step()
                pi_optimizer.step()
                alpha_optim.step()

                # Sync the local copy with the updated global weights
                state_dict = global_ac.state_dict()
                local_ac.load_state_dict(state_dict)

                # Finally, update target networks by polyak averaging.
                with torch.no_grad():
                    for p, p_targ in zip(global_ac.parameters(),
                                         global_ac_targ.parameters()):
                        p_targ.data.copy_((1 - args.polyak) * p.data +
                                          args.polyak * p_targ.data)

                writer.add_scalar("training/pi_loss",
                                  loss_pi.detach().item(), t)
                writer.add_scalar("training/q_loss", loss_q.detach().item(), t)
                writer.add_scalar("training/alpha_loss",
                                  alpha_loss.detach().item(), t)
                writer.add_scalar("training/entropy",
                                  entropy.detach().mean().item(), t)

        if e % args.save_freq == 0 and e > 0 and e != saved_e:
            torch.save(
                global_ac.state_dict(),
                os.path.join(args.save_dir, args.exp_name,
                             "model_torch_{}".format(e)))
            state_dict_trans(
                global_ac.state_dict(),
                os.path.join(args.save_dir, args.exp_name,
                             "model_numpy_{}".format(e)))
            torch.save((e, t, list(scores), list(wins)),
                       os.path.join(args.save_dir, args.exp_name,
                                    "model_data_{}".format(e)))
            print("Saving model at episode:{}".format(t))
            saved_e = e
Example #4
def sac(env_fn, actor_critic=MLPActorCritic, ac_kwargs=dict(), seed=0,
        steps_per_epoch=4000, epochs=100, replay_size=int(1e6), gamma=0.99,
        polyak=0.995, lr=1e-3, alpha=0.2, batch_size=100, start_steps=10000,
        update_after=1000, update_every=50, num_test_episodes=10, max_ep_len=1000,
        logger_kwargs=dict(), save_freq=1000, save_dir=None):
    """
    Soft Actor-Critic (SAC)


    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_critic: The constructor method for a PyTorch Module with an ``act``
            method, a ``pi`` module, a ``q1`` module, and a ``q2`` module.
            The ``act`` method and ``pi`` module should accept batches of
            observations as inputs, and ``q1`` and ``q2`` should accept a batch
            of observations and a batch of actions as inputs. When called,
            ``act``, ``q1``, and ``q2`` should return:

            ===========  ================  ======================================
            Call         Output Shape      Description
            ===========  ================  ======================================
            ``act``      (batch, act_dim)  | Numpy array of actions for each
                                           | observation.
            ``q1``       (batch,)          | Tensor containing one current estimate
                                           | of Q* for the provided observations
                                           | and actions. (Critical: make sure to
                                           | flatten this!)
            ``q2``       (batch,)          | Tensor containing the other current
                                           | estimate of Q* for the provided observations
                                           | and actions. (Critical: make sure to
                                           | flatten this!)
            ===========  ================  ======================================

            Calling ``pi`` should return:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``a``        (batch, act_dim)  | Tensor containing actions from policy
                                           | given observations.
            ``logp_pi``  (batch,)          | Tensor containing log probabilities of
                                           | actions in ``a``. Importantly: gradients
                                           | should be able to flow back into ``a``.
            ===========  ================  ======================================

        ac_kwargs (dict): Any kwargs appropriate for the ActorCritic object
            you provided to SAC.

        seed (int): Seed for random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs)
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs to run and train agent.

        replay_size (int): Maximum length of replay buffer.

        gamma (float): Discount factor. (Always between 0 and 1.)

        polyak (float): Interpolation factor in polyak averaging for target
            networks. Target networks are updated towards main networks
            according to:

            .. math:: \\theta_{\\text{targ}} \\leftarrow
                \\rho \\theta_{\\text{targ}} + (1-\\rho) \\theta

            where :math:`\\rho` is polyak. (Always between 0 and 1, usually
            close to 1.)

        lr (float): Learning rate (used for both policy and value learning).

        alpha (float): Entropy regularization coefficient. (Equivalent to
            inverse of reward scale in the original SAC paper.)

        batch_size (int): Minibatch size for SGD.

        start_steps (int): Number of steps for uniform-random action selection,
            before running real policy. Helps exploration.

        update_after (int): Number of env interactions to collect before
            starting to do gradient descent updates. Ensures replay buffer
            is full enough for useful updates.

        update_every (int): Number of env interactions that should elapse
            between gradient descent updates. Note: Regardless of how long
            you wait between updates, the ratio of env steps to gradient steps
            is locked to 1.

        num_test_episodes (int): Number of episodes to test the deterministic
            policy at the end of each epoch.

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.

    """

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    torch.manual_seed(seed)
    np.random.seed(seed)

    env, test_env = env_fn(), env_fn()
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n

    # Action limit for clamping: critically, assumes all dimensions share the same bound!
    # act_limit = env.action_space.high[0]

    # Create actor-critic module and target networks
    ac = actor_critic(env.observation_space, env.action_space, **ac_kwargs)
    ac_targ = deepcopy(ac)

    # Freeze target networks with respect to optimizers (only update via polyak averaging)
    for p in ac_targ.parameters():
        p.requires_grad = False

    # Experience buffer
    replay_buffer = ReplayBuffer(obs_dim=obs_dim, size=replay_size)

    # Count variables (protip: try to get a feel for how different size networks behave!)
    var_counts = tuple(count_vars(module) for module in [ac.pi, ac.q1, ac.q2])
    logger.log('\nNumber of parameters: \t pi: %d, \t q1: %d, \t q2: %d\n' % var_counts)

    # Set up optimizers for policy and q-function
    pi_optimizer = Adam(ac.pi.parameters(), lr=lr)
    q1_optimizer = Adam(ac.q1.parameters(), lr=lr)
    q2_optimizer = Adam(ac.q2.parameters(), lr=lr)

    # Set up model saving
    logger.setup_pytorch_saver(ac)

    # product action
    def get_actions_info(a_prob):
        a_dis = Categorical(a_prob)
        max_a = torch.argmax(a_prob)
        sample_a = a_dis.sample().cpu()
        z = a_prob == 0.0
        z = z.float() * 1e-8
        log_a_prob = torch.log(a_prob + z)
        return a_prob, log_a_prob, sample_a, max_a

    # Set up function for computing SAC Q-losses
    def compute_loss_q(data):
        o, a, r, o2, d = data['obs'], data['act'], data['rew'], data['obs2'], data['done']

        # Bellman backup for Q functions
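        # Discrete-action SAC target: an expectation over the next-state policy,
        #   backup = r + gamma * (1 - d) * sum_a' pi(a'|s') *
        #            (min_i Q_targ_i(s', a') - alpha * log pi(a'|s'))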
        with torch.no_grad():
            # Target actions come from *current* policy
            a_prob, log_a_prob, sample_a, max_a = get_actions_info(ac.pi(o2))

            # Target Q-values
            q1_pi_targ = ac_targ.q1(o2)
            q2_pi_targ = ac_targ.q2(o2)
            q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
            backup = r + gamma * (1 - d) * (a_prob * (q_pi_targ - alpha * log_a_prob)).sum(dim=1)

        # MSE loss against Bellman backup
        q1 = ac.q1(o).gather(1, a.unsqueeze(-1).long())
        q2 = ac.q2(o).gather(1, a.unsqueeze(-1).long())
        loss_q1 = F.mse_loss(q1, backup.unsqueeze(-1))
        loss_q2 = F.mse_loss(q2, backup.unsqueeze(-1))
        loss_q = loss_q1 + loss_q2

        # Useful info for logging
        q_info = dict(Q1Vals=q1.detach().numpy(),
                      Q2Vals=q2.detach().numpy())

        return loss_q, q_info

    # Set up function for computing SAC pi loss
    def compute_loss_pi(data):
        o = data['obs']
        a_prob, log_a_prob, sample_a, max_a = get_actions_info(ac.pi(o))
        q1_pi = ac.q1(o)
        q2_pi = ac.q2(o)
        q_pi = torch.min(q1_pi, q2_pi)

        # Entropy-regularized policy loss
        loss_pi = (a_prob * (alpha * log_a_prob - q_pi)).mean()
        entropy = torch.sum(log_a_prob * a_prob, dim=1).detach()
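        # Note: this "entropy" is sum_a pi * log pi, i.e. the expected
        # log-probability (the negative entropy); it is only used for logging here.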

        # Useful info for logging
        pi_info = dict(LogPi=entropy.numpy())
        return loss_pi, pi_info

    def update(data):
        # First run one gradient descent step for Q1 and Q2
        q1_optimizer.zero_grad()
        q2_optimizer.zero_grad()
        loss_q, q_info = compute_loss_q(data)
        loss_q.backward()
        q1_optimizer.step()
        q2_optimizer.step()

        # Record things
        logger.store(LossQ=loss_q.detach().item(), **q_info)

        # Freeze Q-networks so you don't waste computational effort
        # computing gradients for them during the policy learning step.
        # for p in q_params:
        #     p.requires_grad = False

        # Next run one gradient descent step for pi.
        pi_optimizer.zero_grad()
        loss_pi, pi_info = compute_loss_pi(data)
        loss_pi.backward()
        pi_optimizer.step()

        # Unfreeze Q-networks so you can optimize it at next DDPG step.
        # for p in q_params:
            # p.requires_grad = True

        # Record things
        logger.store(LossPi=loss_pi.item(), **pi_info)

        # Finally, update target networks by polyak averaging.
        with torch.no_grad():
            for p, p_targ in zip(ac.parameters(), ac_targ.parameters()):
                # NB: copy_ writes the interpolated value into the target params
                # in place; the commented mul_/add_ form below is equivalent.
                p_targ.data.copy_((1 - polyak) * p.data + polyak * p_targ.data)
                # p_targ.data.mul_(polyak)
                # p_targ.data.add_((1 - polyak) * p.data)

    def get_action(o, greedy=False):
        if len(o.shape) == 1:
            o = np.expand_dims(o, axis=0)
        a_prob = ac.act(torch.as_tensor(o, dtype=torch.float32), greedy)
        a_prob, log_a_prob, sample_a, max_a = get_actions_info(a_prob)
        action = sample_a if not greedy else max_a
        return action.item()

    def test_agent():
        for j in range(num_test_episodes):
            o, d, ep_ret, ep_len = test_env.reset(), False, 0, 0
            while not (d or (ep_len == max_ep_len)):
                # Take deterministic actions at test time
                o, r, d, _ = test_env.step(get_action(o, True))
                # test_env.render()
                ep_ret += r
                ep_len += 1
            logger.store(TestEpRet=ep_ret, TestEpLen=ep_len)
            print(ep_ret)

    # Prepare for interaction with environment
    total_steps = steps_per_epoch * epochs
    start_time = time.time()
    scores = []
    o, ep_ret, ep_len = env.reset(), 0, 0
    discard = False
    # Main loop: collect experience in env and update/log each epoch
    for t in range(total_steps):

        # Until start_steps have elapsed, randomly sample actions
        # from a uniform distribution for better exploration. Afterwards,
        # use the learned policy.
        if t > start_steps:
            a = get_action(o)
        else:
            a = env.action_space.sample()

        # Step the env
        o2, r, d, info = env.step(a)
        if info.get('no_data_receive', False):
            discard = True
        env.render()
        ep_ret += r
        ep_len += 1

        # Ignore the "done" signal if it comes from hitting the time
        # horizon (that is, when it's an artificial terminal signal
        # that isn't based on the agent's state)
        d = False if ep_len == max_ep_len or discard else d

        # Store experience to replay buffer
        replay_buffer.store(o, a, r, o2, d)

        # Super critical, easy to overlook step: make sure to update
        # most recent observation!
        o = o2

        # End of trajectory handling
        if d or (ep_len == max_ep_len) or discard:
            logger.store(EpRet=ep_ret, EpLen=ep_len)
            scores.append(ep_ret)
            print("round len:{}, round score: {}, mean score: {}".format(ep_len, ep_ret, np.mean(scores[-100:])))
            o, ep_ret, ep_len = env.reset(), 0, 0
            discard = False


        # Update handling
        if t >= update_after and t % update_every == 0:
            for j in range(update_every):
                batch = replay_buffer.sample_batch(batch_size)
                update(data=batch)

        # Save model every save_freq environment steps
        if t % save_freq == 0 and t > 0:
            torch.save(ac.state_dict(), os.path.join(save_dir, "model"))
            state_dict_trans(ac.state_dict(), os.path.join(save_dir, "SAC_Toothless.numpy"))
            print("Saving model at step:{}".format(t))

        # End of epoch handling
        if (t + 1) % steps_per_epoch == 0:
            epoch = (t + 1) // steps_per_epoch
            if (epoch % save_freq == 0) or (epoch == epochs):
                logger.save_state({'env': env}, None)