Example #1
    def learn(self, n_epochs, verbose=True):
        proc_list = []
        master_q_list = []
        worker_q_list = []
        learn_start_idx = copy.copy(self.total_epochs)

        if self.step_schedule:
            step_lookup = make_schedule(self.step_schedule, n_epochs)

        if self.exp_schedule:
            exp_lookup = make_schedule(self.exp_schedule, n_epochs)

        for i in range(self.n_workers):
            master_q = Queue()
            worker_q = Queue()
            proc = Process(target=worker_fn,
                           args=(worker_q, master_q, self.model_list[0],
                                 self.env_name, self.env_config,
                                 self.postprocessor, self.seed))
            proc.start()
            proc_list.append(proc)
            master_q_list.append(master_q)
            worker_q_list.append(worker_q)

        n_param = self.W_flat_list[0].shape[0]

        rng = default_rng()

        for epoch in range(n_epochs):
            if self.step_schedule:
                self.step_size = step_lookup(epoch)
            if self.exp_schedule:
                self.exp_noise = exp_lookup(epoch)

            if len(self.raw_rew_hist) > 2 and self.reward_stop:
                if self.raw_rew_hist[-1] >= self.reward_stop and self.raw_rew_hist[-2] >= self.reward_stop:
                    early_stop = True
                    break

            seeds = rng.integers(2**32, size=self.n_delta)
            delta_list = []
            top_returns_list = []
            m_returns_list = []
            p_returns_list = []
            states_list = []

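            # Rollout +/- perturbations =============================================================================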
            with torch.no_grad():
                for model_i, W_flat in enumerate(self.W_flat_list):
                    deltas = rng.standard_normal((self.n_delta, n_param))
                    delta_list.append(deltas)

                    W_plus_delta = np.concatenate(
                        (W_flat + (deltas * self.exp_noise),
                         W_flat - (deltas * self.exp_noise)))
                    pm_seeds = np.concatenate([seeds, seeds])  # one seed per +/- pair; keep the base seeds intact across models

                    start = time.time()

                    for i, Ws in enumerate(W_plus_delta):
                        master_q_list[i % self.n_workers].put(
                            (Ws, self.state_mean_list[model_i],
                             self.state_std_list[model_i], pm_seeds[i]))

                    results = []
                    for i, _ in enumerate(W_plus_delta):
                        results.append(worker_q_list[i % self.n_workers].get())

                    end = time.time()
                    t = (end - start)

                    states = []
                    p_returns = []
                    m_returns = []
                    top_returns = []

                    for p_result, m_result in zip(results[:self.n_delta],
                                                  results[self.n_delta:]):
                        ps, pr, plr = p_result
                        ms, mr, mlr = m_result

                        p_returns.append(pr)
                        m_returns.append(mr)

                        top_returns.append(max([pr, mr]))
                        top_states = [ps, ms][np.argmax([pr, mr]).item()]
                        states.append(top_states)

                    states_list.append(states)
                    top_returns_list.append(top_returns)
                    m_returns_list.append(m_returns)
                    p_returns_list.append(p_returns)

                    concat_states = np.concatenate(states)
                    self.state_mean_list[model_i] = update_mean(
                        concat_states, self.state_mean_list[model_i],
                        self.total_steps_list[model_i])
                    self.state_std_list[model_i] = update_std(
                        concat_states, self.state_std_list[model_i],
                        self.total_steps_list[model_i])

                    ep_steps = concat_states.shape[0]
                    self.total_steps_list[model_i] += ep_steps

            # Classifier Update ======================================================================================
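            # Label each perturbation index with the model that achieved the best return there,
            # then fit the classifier to predict the winning model from the states it visited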
            T = np.array(top_returns_list)
            Y = np.argmax(T, axis=0)

            Ytrain = []
            Xtrain = []

            for i, y in enumerate(Y):
                # print(f"states_lists[0][{i}][0] = {states_list[0][i][0]}")
                # print(f"states_lists[1][{i}][0] = {states_list[1][i][0]}")
                for x in states_list[y][i]:
                    Xtrain.append(x)
                    Ytrain.append(y)

            Xtrain = np.array(Xtrain, dtype=np.float32)
            Ytrain = np.array(Ytrain)

            print(Xtrain.shape)
            print(Ytrain.shape)

            loss_hist = fit_model(self.classifier,
                                  Xtrain,
                                  Ytrain,
                                  5,
                                  batch_size=64,
                                  loss_fn=torch.nn.CrossEntropyLoss())

            # ARS Update  ============================================================================================
            with torch.no_grad():
                train_return_list = [[] for _ in range(T.shape[0])]
                train_m_list = [[] for _ in range(T.shape[0])]
                train_p_list = [[] for _ in range(T.shape[0])]
                train_delta_list = [[] for _ in range(T.shape[0])]

                for i, y in enumerate(Y):
                    train_return_list[y].append(top_returns_list[y][i])
                    train_p_list[y].append(p_returns_list[y][i])  # keep plus/minus returns paired with their lists
                    train_m_list[y].append(m_returns_list[y][i])
                    train_delta_list[y].append(delta_list[y][i])

                for i, _ in enumerate(self.W_flat_list):
                    top_returns = train_return_list[i]
                    p_returns = train_p_list[i]
                    m_returns = train_m_list[i]
                    deltas = np.array(train_delta_list[i])

                    if len(top_returns) == 0:
                        continue

                    top_idx = sorted(range(len(top_returns)),
                                     key=lambda k: top_returns[k],
                                     reverse=True)[:self.n_top]
                    p_returns = np.stack(p_returns)[top_idx]
                    m_returns = np.stack(m_returns)[top_idx]
                    #print(f"{i} : {self.model_list[i].policy.state_dict()}")
                    #print(f" {i} : {self.W_flat_list[i]}")
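                    # ARS step: move along the top-n_top deltas, weighted by the
                    # (plus - minus) return differences and scaled by their standard deviation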
                    self.W_flat_list[i] = self.W_flat_list[i] + (
                        self.step_size / (self.n_delta * np.concatenate(
                            (p_returns, m_returns)).std() + 1e-6)) * np.sum(
                                (p_returns - m_returns) * deltas[top_idx].T,
                                axis=1)
                    #print(f"{i} : {self.model_list[i].policy.state_dict()}")
                    print(f" {i} : {self.W_flat_list[i]}")
                # if verbose and epoch % 10 == 0:
                #     print(f"{epoch} : mean return: {l_returns.mean()}, top_return: {l_returns.max()}, fps:{states.shape[0]/t}")

                # self.raw_rew_hist.append(np.stack(top_returns)[top_idx].mean())
                # self.r_hist.append((p_returns.mean() + m_returns.mean())/2)

                self.total_epochs += 1

        for q in master_q_list:
            q.put("STOP")
        for proc in proc_list:
            proc.join()

        #print(f" model 0 state dict before: {self.model_list[0].policy.state_dict()}")
        for i, _ in enumerate(self.model_list):
            print(f" model {i} w_flat: {self.W_flat_list[i]}")
            torch.nn.utils.vector_to_parameters(
                torch.tensor(self.W_flat_list[i]),
                self.model_list[i].policy.parameters())

        self.model = ARSSwitchingModel(self.model_list, self.classifier)

        # self.model.policy.state_means = torch.from_numpy(self.state_mean)
        # self.model.policy.state_std = torch.from_numpy(self.state_std)
        #
        # torch.set_grad_enabled(True)
        return self.model, self.raw_rew_hist[learn_start_idx:], locals()
Example #2
    def learn(self, train_steps):
        """
                runs sac for train_steps

                Returns:
                    model: trained model
                    avg_reward_hist: list with the average reward per episode at each epoch
                    var_dict: dictionary with all locals, for logging/debugging purposes
                """

        torch.set_num_threads(1) # performance issue with data loader

        env = gym.make(self.env_name, **self.env_config)
        if isinstance(env.action_space, gym.spaces.Box):
            act_size = env.action_space.shape[0]
            act_dtype = env.action_space.sample().dtype
        else:
            raise NotImplementedError("trying to use unsupported action space", env.action_space)

        obs_size = env.observation_space.shape[0]

        random_model = RandModel(self.model.act_limit, act_size)
        self.replay_buf = ReplayBuffer(obs_size, act_size, self.replay_buf_size)
        self.target_value_fn = copy.deepcopy(self.model.value_fn)

        pol_opt = torch.optim.Adam(self.model.policy.parameters(), lr=self.sgd_lr)
        val_opt = torch.optim.Adam(self.model.value_fn.parameters(), lr=self.sgd_lr)
        q1_opt = torch.optim.Adam(self.model.q1_fn.parameters(), lr=self.sgd_lr)
        q2_opt = torch.optim.Adam(self.model.q2_fn.parameters(), lr=self.sgd_lr)

        if self.sgd_lr_sched:
            sgd_lookup = make_schedule(self.sgd_lr_sched, train_steps)
        else:
            sgd_lookup = None


        # seed all our RNGs
        env.seed(self.seed)
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)

        # set defaults, and decide if we are using a GPU or not
        use_cuda = torch.cuda.is_available() and self.use_gpu
        device = torch.device("cuda:0" if use_cuda else "cpu")

        self.raw_rew_hist = []
        self.val_loss_hist = []
        self.pol_loss_hist = []
        self.q1_loss_hist = []
        self.q2_loss_hist = []

        progress_bar = tqdm.tqdm(total=train_steps + self.normalize_steps)
        cur_total_steps = 0
        progress_bar.update(0)
        early_stop = False
        norm_obs1 = torch.empty(0)

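        # Collect observations with a random policy to estimate normalization statistics
        # ========================================================================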
        while cur_total_steps < self.normalize_steps:
            ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, random_model, self.env_max_steps)
            norm_obs1 = torch.cat((norm_obs1, ep_obs1))

            ep_steps = ep_rews.shape[0]
            cur_total_steps += ep_steps

            progress_bar.update(ep_steps)
        if self.normalize_steps > 0:
            obs_mean = norm_obs1.mean(axis=0)
            obs_std  = norm_obs1.std(axis=0)
            obs_std[torch.isinf(1/obs_std)] = 1

            self.model.policy.state_means = obs_mean
            self.model.policy.state_std  =  obs_std
            self.model.value_fn.state_means = obs_mean
            self.model.value_fn.state_std = obs_std
            self.target_value_fn.state_means = obs_mean
            self.target_value_fn.state_std = obs_std

            self.model.q1_fn.state_means = torch.cat((obs_mean, torch.zeros(act_size)))
            self.model.q1_fn.state_std = torch.cat((obs_std, torch.ones(act_size)))
            self.model.q2_fn.state_means = self.model.q1_fn.state_means
            self.model.q2_fn.state_std = self.model.q1_fn.state_std

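        # Fill the replay buffer with actions taken from a random model
        # ========================================================================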
        while cur_total_steps < self.exploration_steps:
            ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, random_model, self.env_max_steps)
            self.replay_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

            ep_steps = ep_rews.shape[0]
            cur_total_steps += ep_steps

            progress_bar.update(ep_steps)

        while cur_total_steps < train_steps:
            cur_batch_steps = 0

            # Bail out if we have met our reward threshold
            if len(self.raw_rew_hist) > 2 and self.reward_stop:
                if self.raw_rew_hist[-1] >= self.reward_stop and self.raw_rew_hist[-2] >= self.reward_stop:
                    early_stop = True
                    break

            # collect data with the current policy
            # ========================================================================
            while cur_batch_steps < self.min_steps_per_update:
                ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, self.model, self.env_max_steps)
                self.replay_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

                ep_steps = ep_rews.shape[0]
                cur_batch_steps += ep_steps
                cur_total_steps += ep_steps

                self.raw_rew_hist.append(torch.sum(ep_rews))
                #print(self.raw_rew_hist[-1])


            progress_bar.update(cur_batch_steps)

            for _ in range(min(int(ep_steps), self.iters_per_update)):

                torch.autograd.set_grad_enabled(False)
                # compute targets for Q and V
                # ========================================================================
                replay_obs1, replay_obs2, replay_acts, replay_rews, replay_done = self.replay_buf.sample_batch(self.replay_batch_size)

                q_targ = replay_rews + self.gamma * (1 - replay_done) * self.target_value_fn(replay_obs2)

                noise = torch.randn(self.replay_batch_size, act_size)
                sample_acts, sample_logp = self.model.select_action(replay_obs1, noise)

                q_in = torch.cat((replay_obs1, sample_acts), dim=1)
                q_preds = torch.cat((self.model.q1_fn(q_in), self.model.q2_fn(q_in)), dim=1)
                q_min, q_min_idx = torch.min(q_preds, dim=1)
                q_min = q_min.reshape(-1,1)

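                # Clipped double-Q: use the smaller of the twin Q estimates for the value target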
                v_targ = q_min - self.alpha * sample_logp

                torch.autograd.set_grad_enabled(True)

                # q_fn update
                # ========================================================================
                num_mbatch = int(self.replay_batch_size / self.sgd_batch_size)

                for i in range(num_mbatch):
                    cur_sample = i*self.sgd_batch_size

                    q_in = torch.cat((replay_obs1[cur_sample:cur_sample + self.sgd_batch_size], replay_acts[cur_sample:cur_sample + self.sgd_batch_size]), dim=1)
                    q1_preds = self.model.q1_fn(q_in)
                    q2_preds = self.model.q2_fn(q_in)
                    q1_loss = torch.pow(q1_preds - q_targ[cur_sample:cur_sample + self.sgd_batch_size], 2).mean()
                    q2_loss = torch.pow(q2_preds - q_targ[cur_sample:cur_sample + self.sgd_batch_size], 2).mean()
                    q_loss = q1_loss + q2_loss

                    q1_opt.zero_grad()
                    q2_opt.zero_grad()
                    q_loss.backward()
                    q1_opt.step()
                    q2_opt.step()

                # val_fn update
                # ========================================================================
                for i in range(num_mbatch):
                    cur_sample = i*self.sgd_batch_size

                    # predict and calculate loss for the batch
                    val_preds = self.model.value_fn(replay_obs1[cur_sample:cur_sample + self.sgd_batch_size])
                    val_loss = torch.pow(val_preds - v_targ[cur_sample:cur_sample + self.sgd_batch_size], 2).mean()

                    # do the normal pytorch update
                    val_opt.zero_grad()
                    val_loss.backward()
                    val_opt.step()

                # policy_fn update
                # ========================================================================
                for param in self.model.q1_fn.parameters():
                    param.requires_grad = False

                for i in range(num_mbatch):
                    cur_sample = i*self.sgd_batch_size

                    noise = torch.randn(replay_obs1[cur_sample:cur_sample + self.sgd_batch_size].shape[0], act_size)
                    local_acts, local_logp = self.model.select_action(replay_obs1[cur_sample:cur_sample + self.sgd_batch_size], noise)

                    q_in = torch.cat((replay_obs1[cur_sample:cur_sample + self.sgd_batch_size], local_acts), dim=1)
                    pol_loss = (self.alpha * local_logp - self.model.q1_fn(q_in)).mean()

                    pol_opt.zero_grad()
                    pol_loss.backward()
                    pol_opt.step()

                for param in self.model.q1_fn.parameters():
                    param.requires_grad = True

                # Update target value fn with polyak average
                # ========================================================================
                self.val_loss_hist.append(val_loss.item())
                self.pol_loss_hist.append(pol_loss.item())
                self.q1_loss_hist.append(q1_loss.item())
                self.q2_loss_hist.append(q2_loss.item())

                val_sd = self.model.value_fn.state_dict()
                tar_sd = self.target_value_fn.state_dict()
                for layer in tar_sd:
                    tar_sd[layer] = self.polyak * tar_sd[layer] + (1 - self.polyak) * val_sd[layer]

                self.target_value_fn.load_state_dict(tar_sd)


            #Update LRs
            if sgd_lookup:
                pol_opt.param_groups[0]['lr'] = sgd_lookup(cur_total_steps)
                val_opt.param_groups[0]['lr'] = sgd_lookup(cur_total_steps)
                q1_opt.param_groups[0]['lr'] = sgd_lookup(cur_total_steps)
                q2_opt.param_groups[0]['lr'] = sgd_lookup(cur_total_steps)

        return self.model, self.raw_rew_hist, locals()
Example #3
    def learn(self, n_epochs, verbose=True):
        torch.set_grad_enabled(False)
        proc_list = []
        master_q_list = []
        worker_q_list = []
        learn_start_idx = copy.copy(self.total_epochs)

        if self.step_schedule:
            step_lookup = make_schedule(self.step_schedule, n_epochs)

        if self.exp_schedule:
            exp_lookup = make_schedule(self.exp_schedule, n_epochs)

        for i in range(self.n_workers):
            master_q = Queue()
            worker_q = Queue()

            proc = Process(target=worker_fn,
                           args=(worker_q, master_q, self.algo, self.env_name,
                                 self.postprocessor, self.get_trainable,
                                 self.seed))
            proc.start()
            proc_list.append(proc)
            master_q_list.append(master_q)
            worker_q_list.append(worker_q)

        n_param = self.W_flat.shape[0]

        rng = default_rng()

        for epoch in range(n_epochs):
            if self.step_schedule:
                self.step_size = step_lookup(epoch)
            if self.exp_schedule:
                self.exp_noise = exp_lookup(epoch)

            if len(self.raw_rew_hist) > 2 and self.reward_stop:
                if self.raw_rew_hist[-1] >= self.reward_stop and self.raw_rew_hist[-2] >= self.reward_stop:
                    early_stop = True
                    break

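            # Sample antithetic perturbations and evaluate W_flat +/- exp_noise * delta on the workers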
            deltas = rng.standard_normal((self.n_delta, n_param),
                                         dtype=np.float32)
            #import ipdb; ipdb.set_trace()
            pm_W = np.concatenate((self.W_flat + (deltas * self.exp_noise),
                                   self.W_flat - (deltas * self.exp_noise)))

            start = time.time()
            seeds = np.random.randint(1, 2**32 - 1, self.n_delta, dtype=np.int64)  # explicit dtype avoids overflow where the default int is 32-bit

            for i, Ws in enumerate(pm_W):
                # if self.epoch_seed:
                #     epoch_seed = i%self.n_delta
                # else:
                #     epoch_seed = None

                epoch_seed = int(seeds[i % self.n_delta])
                #epoch_seed = None

                master_q_list[i % self.n_workers].put((Ws, False, epoch_seed))

            results = []
            for i, _ in enumerate(pm_W):
                results.append(worker_q_list[i % self.n_workers].get())

            end = time.time()
            t = (end - start)

            p_returns = []
            m_returns = []
            l_returns = []
            top_returns = []

            for p_result, m_result in zip(results[:self.n_delta],
                                          results[self.n_delta:]):
                pr, plr = p_result
                mr, mlr = m_result

                p_returns.append(pr)
                m_returns.append(mr)
                l_returns.append(plr)
                l_returns.append(mlr)
                top_returns.append(max(pr, mr))

            top_idx = sorted(range(len(top_returns)),
                             key=lambda k: top_returns[k],
                             reverse=True)[:self.n_top]
            p_returns = np.stack(p_returns).astype(np.float32)[top_idx]
            m_returns = np.stack(m_returns).astype(np.float32)[top_idx]
            l_returns = np.stack(l_returns).astype(np.float32)[top_idx]

            self.raw_rew_hist.append(np.stack(top_returns)[top_idx].mean())
            self.r_hist.append((p_returns.mean() + m_returns.mean()) / 2)

            if verbose and epoch % 10 == 0:
                from seagul.zoo3_utils import do_rollout_stable
                env, model = load_zoo_agent(self.env_name, self.algo)
                torch.nn.utils.vector_to_parameters(
                    torch.tensor(self.W_flat, requires_grad=False),
                    self.get_trainable(self.model))
                o, a, r, info = do_rollout_stable(env, self.model)
                if type(o[0]) == collections.OrderedDict:
                    o, _, _ = dict_to_array(o)

                #                o_mdim  = o[200:]
                o_mdim = o
                try:
                    mdim, cdim, _, _ = mesh_dim(o_mdim)
                except Exception:  # fall back if mesh_dim fails on this rollout
                    mdim = np.nan
                    cdim = np.nan

                print(
                    f"{epoch} : mean return: {self.raw_rew_hist[-1]}, top_return: {np.stack(top_returns)[top_idx][0]}, mdim: {mdim}, cdim: {cdim}, eps:{self.n_delta*2/t}"
                )

            self.total_epochs += 1
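            # Standard ARS update over the n_top best perturbations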
            self.W_flat = self.W_flat + (
                self.step_size / (self.n_delta * np.concatenate(
                    (p_returns, m_returns)).std() + 1e-6)) * np.sum(
                        (p_returns - m_returns) * deltas[top_idx].T, axis=1)

        for q in master_q_list:
            q.put((None, True, None))

        for proc in proc_list:
            proc.join()

        torch.nn.utils.vector_to_parameters(
            torch.tensor(self.W_flat, requires_grad=False),
            self.get_trainable(self.model))

        torch.set_grad_enabled(True)
        return self.model, self.raw_rew_hist[learn_start_idx:], locals()
Example #4
def td3(
        env_name,
        train_steps,
        model,
        env_max_steps=0,
        min_steps_per_update=1,
        iters_per_update=200,
        replay_batch_size=64,
        seed=0,
        act_std_schedule=(.1,),
        gamma=0.95,
        polyak=0.995,
        sgd_batch_size=64,
        sgd_lr=3e-4,
        exploration_steps=1000,
        replay_buf_size=int(100000),
        reward_stop=None,
        env_config=None
):
    # Initialize env, and other globals
    # ========================================================================
    if env_config is None:
        env_config = {}
    env = gym.make(env_name, **env_config)
    if isinstance(env.action_space, gym.spaces.Box):
        act_size = env.action_space.shape[0]
        act_dtype = env.action_space.sample().dtype
    else:
        raise NotImplementedError("trying to use unsupported action space", env.action_space)

    obs_size = env.observation_space.shape[0]

    # seed all our RNGs
    env.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    random_model = RandModel(model.act_limit, act_size)
    replay_buf = ReplayBuffer(obs_size, act_size, replay_buf_size)
    target_q1_fn = dill.loads(dill.dumps(model.q1_fn))
    target_q2_fn = dill.loads(dill.dumps(model.q2_fn))
    target_policy = dill.loads(dill.dumps(model.policy))

    for param in target_q1_fn.parameters():
        param.requires_grad = False

    for param in target_q2_fn.parameters():
        param.requires_grad = False

    for param in target_policy.parameters():
        param.requires_grad = False

    act_std_lookup = make_schedule(act_std_schedule, train_steps)
    act_std = act_std_lookup(0)

    pol_opt = torch.optim.Adam(model.policy.parameters(), lr=sgd_lr)
    q1_opt = torch.optim.Adam(model.q1_fn.parameters(), lr=sgd_lr)
    q2_opt = torch.optim.Adam(model.q2_fn.parameters(), lr=sgd_lr)

    progress_bar = tqdm.tqdm(total=train_steps)
    cur_total_steps = 0
    progress_bar.update(0)
    early_stop = False

    raw_rew_hist = []
    pol_loss_hist = []
    q1_loss_hist = []
    q2_loss_hist = []

    # Fill the replay buffer with actions taken from a random model
    # ========================================================================
    while cur_total_steps < exploration_steps:
        ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, random_model, env_max_steps, act_std)
        replay_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

        ep_steps = ep_rews.shape[0]
        cur_total_steps += ep_steps

        progress_bar.update(ep_steps)

    # Keep training until we take train_step environment steps
    # ========================================================================
    while cur_total_steps < train_steps:
        cur_batch_steps = 0

        # Bail out if we have met our reward threshold
        if len(raw_rew_hist) > 2 and reward_stop:
            print(raw_rew_hist[-1])
            if raw_rew_hist[-1] >= reward_stop and raw_rew_hist[-2] >= reward_stop:
                early_stop = True
                break

        # collect data with the current policy
        # ========================================================================
        while cur_batch_steps < min_steps_per_update:
            ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, model, env_max_steps, act_std)
            replay_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

            ep_steps = ep_rews.shape[0]
            cur_batch_steps += ep_steps
            cur_total_steps += ep_steps

            raw_rew_hist.append(torch.sum(ep_rews))

        progress_bar.update(cur_batch_steps)

        # Do the update
        # ========================================================================
        for _ in range(min(int(ep_steps), iters_per_update)):

            # Compute target Q
            replay_obs1, replay_obs2, replay_acts, replay_rews, replay_done = replay_buf.sample_batch(replay_batch_size)

            with torch.no_grad():
                acts_from_target = target_policy(replay_obs2)
                q_in = torch.cat((replay_obs2, acts_from_target), dim=1)
                q_targ = replay_rews + gamma*(1 - replay_done)*target_q1_fn(q_in)

            num_mbatch = int(replay_batch_size / sgd_batch_size)

            # q_fn update
            # ========================================================================
            for i in range(num_mbatch):
                cur_sample = i * sgd_batch_size

                q_in_local = torch.cat((replay_obs1[cur_sample:cur_sample + sgd_batch_size], replay_acts[cur_sample:cur_sample + sgd_batch_size]), dim=1)
                local_qtarg = q_targ[cur_sample:cur_sample + sgd_batch_size]

                q1_loss = ((model.q1_fn(q_in_local) - local_qtarg)**2).mean()

                #q2_preds = model.q2_fn(q_in)
                #q2_loss = (q2_preds - q_targ[cur_sample:cur_sample + sgd_batch_size]**2).mean()
                q_loss = q1_loss# + q2_loss

                q1_opt.zero_grad()
                #q2_opt.zero_grad()
                q_loss.backward()
                q1_opt.step()
                #q2_opt.step()

            # policy_fn update
            # ========================================================================
            for param in model.q1_fn.parameters():
                param.requires_grad = False

            for i in range(num_mbatch):
                cur_sample = i * sgd_batch_size
                local_obs = replay_obs1[cur_sample:cur_sample + sgd_batch_size]
                local_acts = model.policy(local_obs)
                q_in = torch.cat((local_obs, local_acts), dim=1)

                pol_loss = -(model.q1_fn(q_in).mean())

                pol_opt.zero_grad()
                pol_loss.backward()
                pol_opt.step()

            for param in model.q1_fn.parameters():
                param.requires_grad = True

            # Update target value fn with polyak average
            # ========================================================================
            pol_loss_hist.append(pol_loss.item())
            q1_loss_hist.append(q1_loss.item())
            #q2_loss_hist.append(q2_loss.item())

            target_q1_fn = update_target_fn(model.q1_fn, target_q1_fn, polyak)
            target_q2_fn = update_target_fn(model.q2_fn, target_q2_fn, polyak)
            target_policy = update_target_fn(model.policy, target_policy, polyak)
            act_std = act_std_lookup(cur_total_steps)

    return model, raw_rew_hist, locals()
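
Note: td3 above calls a helper update_target_fn(live_fn, target_fn, polyak) to refresh the target networks, but its body is not shown in this example. Below is a minimal sketch of what such a helper could look like, matching the polyak-averaging loop used in Example #2 and the way the helper is called here; it is an illustration under that assumption, not the library's actual implementation.

import torch

def update_target_fn(live_fn, target_fn, polyak):
    # Polyak-average the live network's parameters into the target network
    # and return the updated target, matching how the helper is called above
    with torch.no_grad():
        live_sd = live_fn.state_dict()
        target_sd = target_fn.state_dict()
        for key in target_sd:
            target_sd[key] = polyak * target_sd[key] + (1 - polyak) * live_sd[key]
        target_fn.load_state_dict(target_sd)
    return target_fn
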
Example #5
    def learn(self, n_epochs, verbose=True):
        torch.set_grad_enabled(False)
        proc_list = []
        master_q_list = []
        worker_q_list = []
        learn_start_idx = copy.copy(self.total_epochs)

        if self.step_schedule:
            step_lookup = make_schedule(self.step_schedule, n_epochs)

        if self.exp_schedule:
            exp_lookup = make_schedule(self.exp_schedule, n_epochs)

        for i in range(self.n_workers):
            master_q = Queue()
            worker_q = Queue()
            proc = Process(target=worker_fn, args=(worker_q, master_q, self.model, self.env_name, self.env_config, self.postprocessor, self.seed))
            proc.start()
            proc_list.append(proc)
            master_q_list.append(master_q)
            worker_q_list.append(worker_q)

        n_param = self.W_flat.shape[0]

        rng = default_rng()         

        for epoch in range(n_epochs):
            if self.step_schedule:
                self.step_size = step_lookup(epoch)
            if self.exp_schedule:
                self.exp_noise = exp_lookup(epoch)

            if len(self.raw_rew_hist) > 2 and self.reward_stop:
                if self.raw_rew_hist[-1] >= self.reward_stop and self.raw_rew_hist[-2] >= self.reward_stop:
                    early_stop = True
                    break
            
            deltas = rng.standard_normal((self.n_delta, n_param))
            #import ipdb; ipdb.set_trace()
            pm_W = np.concatenate((self.W_flat+(deltas*self.exp_noise), self.W_flat-(deltas*self.exp_noise)))

            start = time.time()

            for i,Ws in enumerate(pm_W):
                master_q_list[i % self.n_workers].put((Ws ,self.state_mean,self.state_std))
                
            results = []
            for i, _ in enumerate(pm_W):
                results.append(worker_q_list[i % self.n_workers].get())

            end = time.time()
            t = (end - start)
                
            states = np.array([]).reshape(0,self.obs_size)
            p_returns = []
            m_returns = []
            l_returns = []
            top_returns = []

            for p_result, m_result in zip(results[:self.n_delta], results[self.n_delta:]):
                ps, pr, plr = p_result
                ms, mr, mlr = m_result

                states = np.concatenate((states, ms, ps), axis=0)
                p_returns.append(pr)
                m_returns.append(mr)
                l_returns.append(plr); l_returns.append(mlr)
                top_returns.append(max(pr,mr))

            top_idx = sorted(range(len(top_returns)), key=lambda k: top_returns[k], reverse=True)[:self.n_top]
            p_returns = np.stack(p_returns)[top_idx]
            m_returns = np.stack(m_returns)[top_idx]
            l_returns = np.stack(l_returns)[top_idx]

            if verbose and epoch % 10 == 0:
                print(f"{epoch} : mean return: {l_returns.mean()}, top_return: {l_returns.max()}, fps:{states.shape[0]/t}")

            self.raw_rew_hist.append(np.stack(top_returns)[top_idx].mean())
            self.r_hist.append((p_returns.mean() + m_returns.mean())/2)

            ep_steps = states.shape[0]
            self.state_mean = update_mean(states, self.state_mean, self.total_steps)
            self.state_std = update_std(states, self.state_std, self.total_steps)

            self.total_steps += ep_steps
            self.total_epochs += 1

            self.W_flat = self.W_flat + (self.step_size / (self.n_delta * np.concatenate((p_returns, m_returns)).std() + 1e-6)) * np.sum((p_returns - m_returns)*deltas[top_idx].T, axis=1)


        for q in master_q_list:
            q.put("STOP")
        for proc in proc_list:
            proc.join()

        torch.nn.utils.vector_to_parameters(torch.tensor(self.W_flat), self.model.policy.parameters())

        self.model.policy.state_means = torch.from_numpy(self.state_mean)
        self.model.policy.state_std = torch.from_numpy(self.state_std)

        torch.set_grad_enabled(True)
        return self.model, self.raw_rew_hist[learn_start_idx:], locals()
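
Note: Examples #1, #3, and #5 all share the same ARS weight update. The snippet below reproduces that update on toy NumPy arrays so the formula is easier to read in isolation; the sizes and returns here are illustrative stand-ins, not values from the library.

import numpy as np

rng = np.random.default_rng(0)
n_delta, n_top, n_param = 8, 4, 10
step_size = 0.02

W_flat = np.zeros(n_param)
deltas = rng.standard_normal((n_delta, n_param))

# Stand-in returns from rolling out W_flat + exp_noise * delta and W_flat - exp_noise * delta
p_returns = rng.normal(size=n_delta)
m_returns = rng.normal(size=n_delta)

# Keep the n_top perturbations with the best max(plus, minus) return
top_idx = np.argsort(np.maximum(p_returns, m_returns))[::-1][:n_top]
p_top, m_top = p_returns[top_idx], m_returns[top_idx]

# Step along the return-weighted deltas, scaled by the return standard deviation
scale = step_size / (n_delta * np.concatenate((p_top, m_top)).std() + 1e-6)
W_flat = W_flat + scale * np.sum((p_top - m_top) * deltas[top_idx].T, axis=1)
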
Example #6
    def learn(self, total_steps):
        """
        The actual training loop
        Returns:
            model: trained model
            avg_reward_hist: list with the average reward per episode at each epoch
            var_dict: dictionary with all locals, for logging/debugging purposes

        """

        # init everything
        # ==============================================================================
        # seed all our RNGs
        env = gym.make(self.env_name, **self.env_config)

        cur_total_steps = 0
        env.seed(self.seed)
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)

        progress_bar = tqdm.tqdm(total=total_steps)
        lr_lookup = make_schedule(self.lr_schedule, total_steps)

        self.sgd_lr = lr_lookup(0)

        progress_bar.update(0)
        early_stop = False
        self.pol_opt = torch.optim.RMSprop(self.model.policy.parameters(),
                                           lr=lr_lookup(cur_total_steps))
        self.val_opt = torch.optim.RMSprop(self.model.value_fn.parameters(),
                                           lr=lr_lookup(cur_total_steps))

        # Train until we hit our total steps or reach our reward threshold
        # ==============================================================================
        while cur_total_steps < total_steps:
            batch_obs = torch.empty(0)
            batch_act = torch.empty(0)
            batch_adv = torch.empty(0)
            batch_discrew = torch.empty(0)
            cur_batch_steps = 0

            # Bail out if we have met our reward threshold
            if len(self.raw_rew_hist) > 2 and self.reward_stop:
                if self.raw_rew_hist[-1] >= self.reward_stop and self.raw_rew_hist[-2] >= self.reward_stop:
                    early_stop = True
                    break

            # construct batch data from rollouts
            # ==============================================================================
            while cur_batch_steps < self.epoch_batch_size:
                ep_obs, ep_act, ep_rew, ep_steps, ep_term = do_rollout(
                    env, self.model, self.env_no_term_steps)

                cur_batch_steps += ep_steps
                cur_total_steps += ep_steps

                #print(sum(ep_rew).item())
                self.raw_rew_hist.append(sum(ep_rew).item())
                #print("Rew:", sum(ep_rew).item())
                batch_obs = torch.cat((batch_obs, ep_obs.clone()))
                batch_act = torch.cat((batch_act, ep_act.clone()))

                if self.normalize_return:
                    self.rew_std = update_std(ep_rew, self.rew_std,
                                              cur_total_steps)
                    ep_rew = ep_rew / (self.rew_std + 1e-6)

                if ep_term:
                    ep_rew = torch.cat((ep_rew, torch.zeros(1, 1)))
                else:
                    ep_rew = torch.cat((ep_rew, self.model.value_fn(
                        ep_obs[-1]).detach().reshape(1, 1).clone()))

                ep_discrew = discount_cumsum(ep_rew, self.gamma)[:-1]
                batch_discrew = torch.cat((batch_discrew, ep_discrew.clone()))

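                # GAE: advantages are the discounted (gamma * lam) sum of TD residuals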
                with torch.no_grad():
                    ep_val = torch.cat((self.model.value_fn(ep_obs),
                                        ep_rew[-1].reshape(1, 1).clone()))
                    deltas = ep_rew[:-1] + self.gamma * ep_val[1:] - ep_val[:-1]

                ep_adv = discount_cumsum(deltas, self.gamma * self.lam)
                # make sure our advantages are zero mean and unit variance

                batch_adv = torch.cat((batch_adv, ep_adv.clone()))

            # PostProcess epoch and update weights
            # ==============================================================================
            if self.normalize_adv:
                # adv_mean = update_mean(batch_adv, adv_mean, cur_total_steps)
                # adv_var = update_std(batch_adv, adv_var, cur_total_steps)
                batch_adv = (batch_adv - batch_adv.mean()) / (batch_adv.std() +
                                                              1e-6)

            # Update the policy using the PPO loss
            for pol_epoch in range(self.sgd_epochs):
                pol_loss, approx_kl = self.policy_update(
                    batch_act, batch_obs, batch_adv)
                if approx_kl > self.target_kl:
                    print("KL Stop")
                    break

            for val_epoch in range(self.sgd_epochs):
                val_loss = self.value_update(batch_obs, batch_discrew)

            # update observation mean and variance

            if self.normalize_obs:
                self.obs_mean = update_mean(batch_obs, self.obs_mean,
                                            cur_total_steps)
                self.obs_std = update_std(batch_obs, self.obs_std,
                                          cur_total_steps)
                self.model.policy.state_means = self.obs_mean
                self.model.value_fn.state_means = self.obs_mean
                self.model.policy.state_std = self.obs_std
                self.model.value_fn.state_std = self.obs_std

            sgd_lr = lr_lookup(cur_total_steps)

            self.old_model = copy.deepcopy(self.model)
            self.val_loss_hist.append(val_loss.detach())
            self.pol_loss_hist.append(pol_loss.detach())
            self.lrp_hist.append(
                self.pol_opt.state_dict()['param_groups'][0]['lr'])
            self.lrv_hist.append(
                self.val_opt.state_dict()['param_groups'][0]['lr'])
            self.kl_hist.append(approx_kl.detach())
            self.entropy_hist.append(self.model.policy.logstds.detach())

            progress_bar.update(cur_batch_steps)

        progress_bar.close()
        return self.model, self.raw_rew_hist, locals()
Example #7
def ppo_dim(
        env_name,
        total_steps,
        model,
        transient_length=50,
        act_std_schedule=(0.7,),
        epoch_batch_size=2048,
        gamma=0.99,
        lam=0.95,
        eps=0.2,
        seed=0,
        entropy_coef=0.0,
        sgd_batch_size=1024,
        lr_schedule=(3e-4,),
        sgd_epochs=10,
        target_kl=float('inf'),
        val_coef=.5,
        clip_val=True,
        env_no_term_steps=0,
        use_gpu=False,
        reward_stop=None,
        normalize_return=True,
        normalize_obs=True,
        normalize_adv=True,
        env_config=None
):
    """
    Implements proximal policy optimization with clipping; episode rewards are additionally scaled by a dimension measure (var_dim) of the rollout observations

    Args:
        env_name: name of the openAI gym environment to solve
        total_steps: number of timesteps to run the PPO for
        model: model from seagul.rl.models. Contains policy and value fn
        transient_length: number of initial steps of each episode ignored when computing the dimension of the observations
        act_std_schedule: schedule for the policy's action standard deviation. Will linearly interpolate values
        epoch_batch_size: number of environment steps to take per batch, total steps will be num_epochs*epoch_batch_size
        seed: seed for all the rngs
        gamma: discount applied to future rewards, usually close to 1
        lam: lambda for the Advantage estimation, usually close to 1
        eps: epsilon for the clipping, usually .1 or .2
        sgd_batch_size: minibatch size for the policy and value function updates
        lr_schedule: learning rate schedule for the policy and value optimizers
        sgd_epochs: how many epochs to use for each policy/value update
        target_kl: max KL before breaking
        use_gpu:  want to use the GPU? set to true
        reward_stop: reward value to stop if we achieve
        normalize_return: should we normalize the return?
        env_config: dictionary containing kwargs to pass to the environment

    Returns:
        model: trained model
        avg_reward_hist: list with the average reward per episode at each epoch
        var_dict: dictionary with all locals, for logging/debugging purposes

    Example:
        from seagul.rl.algos import ppo_dim
        from seagul.nn import MLP
        from seagul.rl.models import PPOModel
        import torch

        input_size = 3
        output_size = 1
        layer_size = 64
        num_layers = 2

        policy = MLP(input_size, output_size, num_layers, layer_size)
        value_fn = MLP(input_size, 1, num_layers, layer_size)
        model = PPOModel(policy, value_fn)

        model, rews, var_dict = ppo_dim("Pendulum-v0", 10000, model)

    """

    # init everything
    # ==============================================================================
    torch.set_num_threads(1)

    if env_config is None:
        env_config = {}
    env = gym.make(env_name, **env_config)
    if isinstance(env.action_space, gym.spaces.Box):
        act_size = env.action_space.shape[0]
        act_dtype = torch.double
    else:
        raise NotImplementedError("trying to use unsupported action space", env.action_space)

    actstd_lookup = make_schedule(act_std_schedule, total_steps)
    lr_lookup = make_schedule(lr_schedule, total_steps)

    model.action_var = actstd_lookup(0)
    sgd_lr = lr_lookup(0)

    obs_size = env.observation_space.shape[0]
    obs_mean = torch.zeros(obs_size)
    obs_std = torch.ones(obs_size)
    rew_mean = torch.zeros(1)
    rew_std = torch.ones(1)

    # copy.deepcopy broke for me with older version of torch. Using pickle for this is weird but works fine
    old_model = pickle.loads(pickle.dumps(model))

    # seed all our RNGs
    env.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # set defaults, and decide if we are using a GPU or not
    use_cuda = torch.cuda.is_available() and use_gpu
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # init logging stuff
    raw_rew_hist = []
    val_loss_hist = []
    pol_loss_hist = []
    progress_bar = tqdm.tqdm(total=total_steps)
    cur_total_steps = 0
    progress_bar.update(0)
    early_stop = False

    # Train until we hit our total steps or reach our reward threshold
    # ==============================================================================
    while cur_total_steps < total_steps:
        pol_opt = torch.optim.Adam(model.policy.parameters(), lr=sgd_lr)
        val_opt = torch.optim.Adam(model.value_fn.parameters(), lr=sgd_lr)

        batch_obs = torch.empty(0)
        batch_act = torch.empty(0)
        batch_adv = torch.empty(0)
        batch_discrew = torch.empty(0)
        cur_batch_steps = 0

        # Bail out if we have met our reward threshold
        if len(raw_rew_hist) > 2 and reward_stop:
            if raw_rew_hist[-1] >= reward_stop and raw_rew_hist[-2] >= reward_stop:
                early_stop = True
                break

        # construct batch data from rollouts
        # ==============================================================================
        while cur_batch_steps < epoch_batch_size:
            ep_obs, ep_act, ep_rew, ep_steps, ep_term = do_rollout(env, model, env_no_term_steps)
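            # ppo_dim modification: scale episode rewards by var_dim of the post-transient observations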
            ep_rew /= var_dim(ep_obs[transient_length:],order=1)


            raw_rew_hist.append(sum(ep_rew).item())
            batch_obs = torch.cat((batch_obs, ep_obs[:-1]))
            batch_act = torch.cat((batch_act, ep_act[:-1]))

            if not ep_term:
                ep_rew[-1] = model.value_fn(ep_obs[-1]).detach()

            ep_discrew = discount_cumsum(ep_rew, gamma)

            if normalize_return:
                rew_mean = update_mean(batch_discrew, rew_mean, cur_total_steps)
                rew_std = update_std(ep_discrew, rew_std, cur_total_steps)
                ep_discrew = ep_discrew / (rew_std + 1e-6)

            batch_discrew = torch.cat((batch_discrew, ep_discrew[:-1]))

            ep_val = model.value_fn(ep_obs)

            deltas = ep_rew[:-1] + gamma * ep_val[1:] - ep_val[:-1]
            ep_adv = discount_cumsum(deltas.detach(), gamma * lam)
            batch_adv = torch.cat((batch_adv, ep_adv))

            cur_batch_steps += ep_steps
            cur_total_steps += ep_steps

        # make sure our advantages are zero mean and unit variance
        if normalize_adv:
            #adv_mean = update_mean(batch_adv, adv_mean, cur_total_steps)
            #adv_var = update_std(batch_adv, adv_var, cur_total_steps)
            batch_adv = (batch_adv - batch_adv.mean()) / (batch_adv.std() + 1e-6)


        num_mbatch = int(batch_obs.shape[0] / sgd_batch_size)
        # Update the policy using the PPO loss
        for pol_epoch in range(sgd_epochs):
            for i in range(num_mbatch):
                # policy update
                # ========================================================================
                cur_sample = i * sgd_batch_size

                # Transfer to GPU (if GPU is enabled, else this does nothing)
                local_obs = batch_obs[cur_sample:cur_sample + sgd_batch_size]
                local_act = batch_act[cur_sample:cur_sample + sgd_batch_size]
                local_adv = batch_adv[cur_sample:cur_sample + sgd_batch_size]
                local_val = batch_discrew[cur_sample:cur_sample + sgd_batch_size]

                # Compute the loss
                logp = model.get_logp(local_obs, local_act).reshape(-1, act_size)
                old_logp = old_model.get_logp(local_obs, local_act).reshape(-1, act_size)
                mean_entropy = -(logp*torch.exp(logp)).mean()

                r = torch.exp(logp - old_logp)
                clip_r = torch.clamp(r, 1 - eps, 1 + eps)

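                # PPO clipped surrogate objective with an entropy bonus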
                pol_loss = -torch.min(r * local_adv, clip_r * local_adv).mean() - entropy_coef*mean_entropy

                approx_kl = ((logp - old_logp)**2).mean()
                if approx_kl > target_kl:
                    break

                pol_opt.zero_grad()
                pol_loss.backward()
                pol_opt.step()

                # value_fn update
                # ========================================================================
                val_preds = model.value_fn(local_obs)
                if clip_val:
                    old_val_preds = old_model.value_fn(local_obs)
                    val_preds_clipped = old_val_preds + torch.clamp(val_preds - old_val_preds, -eps, eps)
                    val_loss1 = (val_preds_clipped - local_val)**2
                    val_loss2 = (val_preds - local_val)**2
                    val_loss = val_coef*torch.max(val_loss1, val_loss2).mean()
                else:
                    val_loss = val_coef*((val_preds - local_val) ** 2).mean()

                val_opt.zero_grad()
                val_loss.backward()
                val_opt.step()

        # update observation mean and variance

        if normalize_obs:
            obs_mean = update_mean(batch_obs, obs_mean, cur_total_steps)
            obs_std = update_std(batch_obs, obs_std, cur_total_steps)
            model.policy.state_means = obs_mean
            model.value_fn.state_means = obs_mean
            model.policy.state_std = obs_std
            model.value_fn.state_std = obs_std

        model.action_std = actstd_lookup(cur_total_steps)
        sgd_lr = lr_lookup(cur_total_steps)

        old_model = pickle.loads(pickle.dumps(model))
        val_loss_hist.append(val_loss.detach())
        pol_loss_hist.append(pol_loss.detach())

        progress_bar.update(cur_batch_steps)

    progress_bar.close()
    return model, raw_rew_hist, locals()