Code Example #1
def cdim_div(rews, obs, acts):
    # Reward postprocessor: scale the rewards by 1/cdim of the observation trajectory.
    # Skip the initial transient (first 200 steps) for full-length 1000-step rollouts.
    if obs.shape[0] == 1000:
        gait_start = 200
    else:
        gait_start = 0

    # mesh_dim returns four values; the second is the cdim estimate.
    _, c, _, _ = mesh_dim(obs[gait_start:])
    c = np.clip(c, 0, len(rews))  # keep the estimate in a sane range before dividing
    return rews / c
Code Example #2
    def learn(self, n_epochs, verbose=True):
        """Run the parallel random-search training loop for n_epochs."""
        # Parameters are updated directly from sampled returns, so gradients are not needed.
        torch.set_grad_enabled(False)
        proc_list = []
        master_q_list = []
        worker_q_list = []
        learn_start_idx = copy.copy(self.total_epochs)

        if self.step_schedule:
            step_lookup = make_schedule(self.step_schedule, n_epochs)

        if self.exp_schedule:
            exp_lookup = make_schedule(self.exp_schedule, n_epochs)

        # Spawn one persistent worker process per (master_q, worker_q) pair.
        for i in range(self.n_workers):
            master_q = Queue()
            worker_q = Queue()

            proc = Process(target=worker_fn,
                           args=(worker_q, master_q, self.algo, self.env_name,
                                 self.postprocessor, self.get_trainable,
                                 self.seed))
            proc.start()
            proc_list.append(proc)
            master_q_list.append(master_q)
            worker_q_list.append(worker_q)

        n_param = self.W_flat.shape[0]

        rng = default_rng()

        for epoch in range(n_epochs):
            if self.step_schedule:
                self.step_size = step_lookup(epoch)
            if self.exp_schedule:
                self.exp_noise = exp_lookup(epoch)

            # Early stopping: quit once the last two epochs both meet the reward threshold.
            if len(self.raw_rew_hist) > 2 and self.reward_stop:
                if (self.raw_rew_hist[-1] >= self.reward_stop
                        and self.raw_rew_hist[-2] >= self.reward_stop):
                    early_stop = True
                    break

            # Sample n_delta random directions and build antithetic (+/-) parameter candidates.
            deltas = rng.standard_normal((self.n_delta, n_param),
                                         dtype=np.float32)
            pm_W = np.concatenate((self.W_flat + (deltas * self.exp_noise),
                                   self.W_flat - (deltas * self.exp_noise)))

            start = time.time()
            # One seed per direction, shared by the +delta and -delta evaluations.
            seeds = np.random.randint(1, 2**32 - 1, self.n_delta)

            # Round-robin the candidate parameter vectors to the workers.
            # Each message is (weights, stop_flag, seed).
            for i, Ws in enumerate(pm_W):
                epoch_seed = int(seeds[i % self.n_delta])
                master_q_list[i % self.n_workers].put((Ws, False, epoch_seed))

            # Collect one result per candidate, in the same order they were dispatched.
            results = []
            for i, _ in enumerate(pm_W):
                results.append(worker_q_list[i % self.n_workers].get())

            end = time.time()
            t = end - start  # wall-clock time for this batch of rollouts

            p_returns = []
            m_returns = []
            l_returns = []
            top_returns = []

            # The first n_delta results came from the +delta candidates, the rest from -delta.
            for p_result, m_result in zip(results[:self.n_delta],
                                          results[self.n_delta:]):
                pr, plr = p_result
                mr, mlr = m_result

                p_returns.append(pr)
                m_returns.append(mr)
                l_returns.append(plr)
                l_returns.append(mlr)
                top_returns.append(max(pr, mr))

            # Keep only the n_top best-performing directions for the update.
            top_idx = sorted(range(len(top_returns)),
                             key=lambda k: top_returns[k],
                             reverse=True)[:self.n_top]
            p_returns = np.stack(p_returns).astype(np.float32)[top_idx]
            m_returns = np.stack(m_returns).astype(np.float32)[top_idx]
            l_returns = np.stack(l_returns).astype(np.float32)[top_idx]

            self.raw_rew_hist.append(np.stack(top_returns)[top_idx].mean())
            self.r_hist.append((p_returns.mean() + m_returns.mean()) / 2)

            # Every 10 epochs, evaluate the current (unperturbed) parameters and log diagnostics.
            if verbose and epoch % 10 == 0:
                from seagul.zoo3_utils import do_rollout_stable
                # load_zoo_agent and dict_to_array are assumed to be imported elsewhere in the module.
                env, _ = load_zoo_agent(self.env_name, self.algo)
                torch.nn.utils.vector_to_parameters(
                    torch.tensor(self.W_flat, requires_grad=False),
                    self.get_trainable(self.model))
                o, a, r, info = do_rollout_stable(env, self.model)
                if isinstance(o[0], collections.OrderedDict):
                    o, _, _ = dict_to_array(o)

                o_mdim = o
                try:
                    mdim, cdim, _, _ = mesh_dim(o_mdim)
                except Exception:
                    mdim = np.nan
                    cdim = np.nan

                print(
                    f"{epoch} : mean return: {self.raw_rew_hist[-1]}, top_return: {np.stack(top_returns)[top_idx][0]}, mdim: {mdim}, cdim: {cdim}, eps:{self.n_delta*2/t}"
                )

            self.total_epochs += 1

            # ARS-style update: step along the return-weighted sum of the top directions,
            # normalized by the spread of the selected returns.
            return_std = np.concatenate((p_returns, m_returns)).std()
            self.W_flat = self.W_flat + (
                self.step_size / (self.n_delta * return_std + 1e-6)) * np.sum(
                    (p_returns - m_returns) * deltas[top_idx].T, axis=1)

        # Tell each worker to exit, then wait for the processes to finish.
        for q in master_q_list:
            q.put((None, True, None))

        for proc in proc_list:
            proc.join()

        # Write the final flat parameter vector back into the model.
        torch.nn.utils.vector_to_parameters(
            torch.tensor(self.W_flat, requires_grad=False),
            self.get_trainable(self.model))

        torch.set_grad_enabled(True)
        return self.model, self.raw_rew_hist[learn_start_idx:], locals()
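A worker-side sketch of the queue protocol used above, inferred from the messages learn() exchanges: it sends (weights, stop_flag, seed) tuples on master_q, sends (None, True, None) to shut a worker down, and reads one 2-tuple result per candidate back from worker_q. The body of worker_fn below is an illustrative assumption, not the project's actual implementation; in particular, the meaning of the two returned values and the import location of load_zoo_agent are guesses.

import numpy as np
import torch
from seagul.zoo3_utils import do_rollout_stable  # same import used inside learn()
from seagul.zoo3_utils import load_zoo_agent     # assumed location; learn() uses this name unqualified


def worker_fn(worker_q, master_q, algo, env_name, postprocessor, get_trainable, seed):
    # Hypothetical worker loop matching the argument order passed to Process() in learn().
    # `seed` (the learner's base seed) is not used in this sketch.
    env, model = load_zoo_agent(env_name, algo)

    while True:
        W_flat, stop_flag, epoch_seed = master_q.get()   # message format used by learn()
        if stop_flag:                                    # learn() sends (None, True, None) to shut down
            break

        # Load the candidate parameter vector into the policy, mirroring learn().
        torch.nn.utils.vector_to_parameters(
            torch.tensor(W_flat, requires_grad=False), get_trainable(model))

        if epoch_seed is not None:
            np.random.seed(epoch_seed)  # assumption: the per-direction seed controls rollout noise

        obs, acts, rews, _ = do_rollout_stable(env, model)
        raw_return = float(np.sum(rews))
        post_return = float(np.sum(postprocessor(rews, obs, acts)))

        # learn() unpacks each result as a 2-tuple (e.g. `pr, plr = p_result`);
        # which value is which in the real worker may differ from this sketch.
        worker_q.put((post_return, raw_return))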
Code Example #3
def cdim_mul(rews, obs, acts):
    # Reward postprocessor: scale the rewards by the cdim of the observation trajectory.
    _, c, _, _ = mesh_dim(obs)
    return c * rews
Code Example #4
def mdim_div(rews, obs, acts):
    # Reward postprocessor: scale the rewards by 1/mdim of the observation trajectory.
    m, _, _, _ = mesh_dim(obs)
    return rews / m
Code Example #5
def mdim_mul(rews, obs, acts):
    # Reward postprocessor: scale the rewards by the mdim of the observation trajectory.
    m, _, _, _ = mesh_dim(obs)
    return m * rews
Code Example #6
File: run_all.py  Project: sgillen/lorenz
def cdim_div(rews, obs, acts):
    # Same as Example #1, but without skipping an initial transient.
    _, c, _, _ = mesh_dim(obs)
    return rews / c
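All of these postprocessors share the signature (rews, obs, acts) -> adjusted rews and differ only in whether they multiply or divide the rewards by the mdim or cdim value returned by mesh_dim. Below is a minimal stand-alone sketch of that contract; mesh_dim is stubbed out so the snippet runs without the seagul project installed, and the stub's return values are arbitrary.

import numpy as np


def mesh_dim(obs):
    # Stub for illustration only: the real estimator returns four values,
    # of which the first (mdim) and second (cdim) are used by the examples above.
    return 1.8, 1.4, None, None


def mdim_div(rews, obs, acts):
    m, _, _, _ = mesh_dim(obs)
    return rews / m


rews = np.ones(500, dtype=np.float32)   # dummy per-step rewards
obs = np.random.randn(500, 8)           # dummy observation trajectory
acts = np.random.randn(500, 2)          # dummy actions (unused by these postprocessors)
print(mdim_div(rews, obs, acts).sum())  # total return scaled by 1/mdim under the stub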
Code Example #7
File: mesh.py  Project: ntalele/seagul
    # Roll out the baseline ('identity') policy and estimate the mesh dimensions
    # of its normalized state trajectory. env_name and policy_dict are defined
    # earlier in the script.
    seed = 2
    ep_length = 10000
    policy = policy_dict['identity'][seed]
    env = gym.make(env_name)
    o, a, r, _ = do_long_rollout(env, policy, ep_length=ep_length)
    # o,a,r,l = do_rollout(env, policy, render=True)

    plt.plot(o)
    plt.show()

    # Drop the initial transient, then normalize states with the policy's running statistics.
    target = o[200:]
    target = (target - policy.state_means) / policy.state_std
    # target = (target - target.mean(dim=0))/target.std(dim=0)

    print(sum(r))  # total reward of the rollout
    m, c, l, d = mesh_dim(target)
    print(m)       # mesh dimension (mdim)
    print(c)       # cdim

    # ==============
    # Repeat for the policy trained with the mdim_div postprocessor.
    policy = policy_dict['mdim_div'][seed]
    o2, a2, r2, _ = do_long_rollout(env, policy, ep_length=ep_length)
    # o2,a2,r2,l2 = do_rollout(env, policy, render=True)
    plt.plot(o2)
    plt.figure()

    target2 = o2[200:]
    target2 = (target2 - policy.state_means) / policy.state_std
    # target2 = (target2 - target2.mean(dim=0))/target2.std(dim=0)