Code example #1
    def step_wait(self):
        obs = []
        for e in range(self.nb_env):
            (
                ob,
                self.buf_rews[e],
                self.buf_dones[e],
                self.buf_infos[e],
            ) = self.envs[e].step(self.actions[e])
            if self.buf_dones[e]:
                ob = self.envs[e].reset()
            obs.append(ob)
        obs = listd_to_dlist(obs)
        new_obs = {}
        for k, v in dummy_handle_ob(obs).items():
            if self._is_tensor_key(k):
                new_obs[k] = torch.stack(v)
            else:
                new_obs[k] = v
        self.buf_obs = new_obs

        return (
            self.buf_obs,
            torch.tensor(self.buf_rews),
            torch.tensor(self.buf_dones),
            self.buf_infos,
        )
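
A pattern shared by several of these examples: each environment returns an observation dict, the per-env dicts are merged with listd_to_dlist, and tensor-valued keys are stacked along a new leading (batch) dimension. A minimal sketch of that batching step, using a hypothetical "screen" key and shape that are not from the source:

import torch

# one observation dict per environment (hypothetical key and shape)
per_env_obs = [{"screen": torch.zeros(3, 84, 84)} for _ in range(4)]

# equivalent of listd_to_dlist followed by torch.stack on tensor keys
dlist = {k: [o[k] for o in per_env_obs] for k in per_env_obs[0]}
batched = {k: torch.stack(v) for k, v in dlist.items()}
print(batched["screen"].shape)  # torch.Size([4, 3, 84, 84])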
Code example #2
    def run(self):
        for epoch_id in self.epoch_ids:
            reward_buf = 0
            for net_path in self.log_dir_helper.network_paths_at_epoch(
                    epoch_id):
                self.network.load_state_dict(
                    torch.load(net_path,
                               map_location=lambda storage, loc: storage))
                self.network.eval()

                internals = listd_to_dlist(
                    [self.network.new_internals(self.device)])
                next_obs = dtensor_to_dev(self.env_mgr.reset(), self.device)
                self.env_mgr.render()

                episode_complete = False
                while not episode_complete:
                    obs = next_obs
                    with torch.no_grad():
                        actions, _, internals = self.actor.act(
                            self.network, obs, internals)
                    next_obs, rewards, terminals, infos = self.env_mgr.step(
                        actions)
                    self.env_mgr.render()
                    next_obs = dtensor_to_dev(next_obs, self.device)

                    reward_buf += rewards[0]

                    if terminals[0]:
                        episode_complete = True

                print(f"EPOCH_ID: {epoch_id} REWARD: {reward_buf}")
Code example #3
    def reset(self):
        obs = []
        for e in range(self.nb_env):
            ob = self.envs[e].reset()
            obs.append(ob)
        obs = listd_to_dlist(obs)
        new_obs = {}
        for k, v in dummy_handle_ob(obs).items():
            if self._is_tensor_key(k):
                new_obs[k] = torch.stack(v)
            else:
                new_obs[k] = v
        self.buf_obs = new_obs
        return self.buf_obs
Code example #4
    def read(self):
        exp_list, last_obs, is_weights = self._sample()
        exp_dev_list = [
            self._exp_to_dev(e, self.target_device) for e in exp_list
        ]
        # will be list of dicts, convert to dict of lists
        dict_of_list = listd_to_dlist(exp_dev_list)
        # get next obs
        dict_of_list["next_observation"] = last_obs
        # importance sampling weights
        dict_of_list["importance_sample_weights"] = is_weights

        # return named tuple
        return namedtuple(
            self.__class__.__name__,
            ["importance_sample_weights", "next_observation"] + self._keys,
        )(**dict_of_list)
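
As the comments above note, listd_to_dlist converts a list of dicts into a dict of lists, and the result is unpacked into a dynamically built namedtuple so fields can be read by name. A minimal sketch with made-up keys, not the library's implementation:

from collections import namedtuple

def listd_to_dlist_sketch(list_of_dicts):
    # [{"r": 1, "v": 2}, {"r": 3, "v": 4}] -> {"r": [1, 3], "v": [2, 4]}
    return {k: [d[k] for d in list_of_dicts] for k in list_of_dicts[0]}

dict_of_list = listd_to_dlist_sketch([{"rewards": 1.0}, {"rewards": 0.0}])
dict_of_list["importance_sample_weights"] = [0.5, 1.0]

Batch = namedtuple("Batch", list(dict_of_list.keys()))
batch = Batch(**dict_of_list)
print(batch.importance_sample_weights)  # [0.5, 1.0]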
Code example #5
    def act_batch(self, network, batch_obs, batch_terminals, batch_actions,
                  internals, device):
        exp_cache = []

        for obs, actions, terminals in zip(batch_obs, batch_actions,
                                           batch_terminals):
            preds, internals, _ = network(obs, internals)
            exp_cache.append(self._process_exp(preds, actions))

            # np.where returns a single-element tuple containing the indices
            terminal_inds = np.where(terminals)[0]
            for i in terminal_inds:
                for k, v in network.new_internals(device).items():
                    internals[k][i] = v

        exp = listd_to_dlist(exp_cache)
        return torch.stack(exp['log_probs']), torch.stack(
            exp['values']), torch.stack(exp['entropies'])
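
The np.where(terminals)[0] idiom relies on np.where returning a one-element tuple of index arrays when it is called with only a condition; indexing with [0] yields the indices of the environments that just terminated:

import numpy as np

terminals = np.array([False, True, False, True])
print(np.where(terminals))     # (array([1, 3]),)  -- a single-element tuple
print(np.where(terminals)[0])  # array([1, 3])     -- indices of terminal envs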
Code example #6
    def run(self):
        local_step_count = global_step_count = self.initial_step_count
        ep_rewards = torch.zeros(self.nb_env)

        obs = dtensor_to_dev(self.env_mgr.reset(), self.device)
        internals = listd_to_dlist(
            [
                self.network.new_internals(self.device)
                for _ in range(self.nb_env)
            ]
        )
        start_time = time()
        while global_step_count < self.nb_step:
            actions, internals = self.agent.act(self.network, obs, internals)
            next_obs, rewards, terminals, infos = self.env_mgr.step(actions)
            next_obs = dtensor_to_dev(next_obs, self.device)

            self.agent.observe(
                obs,
                rewards.to(self.device).float(),
                terminals.to(self.device).float(),
                infos,
            )
            for i, terminal in enumerate(terminals):
                if terminal:
                    for k, v in self.network.new_internals(self.device).items():
                        internals[k][i] = v

            # Perform state updates
            local_step_count += self.nb_env
            global_step_count += self.nb_env * self.world_size
            ep_rewards += rewards.float()
            obs = next_obs

            term_rewards = []
            for i, terminal in enumerate(terminals):
                if terminal:
                    for k, v in self.network.new_internals(self.device).items():
                        internals[k][i] = v
                    term_rewards.append(ep_rewards[i].item())
                    ep_rewards[i].zero_()

            if term_rewards:
                term_reward = np.mean(term_rewards)
                delta_t = time() - start_time
                self.logger.info(
                    "RANK: {} "
                    "GLOBAL STEP: {} "
                    "REWARD: {} "
                    "GLOBAL STEP/S: {} "
                    "LOCAL STEP/S: {}".format(
                        self.global_rank,
                        global_step_count,
                        term_reward,
                        (global_step_count - self.initial_step_count) / delta_t,
                        (local_step_count - self.initial_step_count) / delta_t,
                    )
                )

            # Learn
            if self.agent.is_ready():
                _, _ = self.agent.learn_step(
                    self.updater, self.network, next_obs, internals
                )

                self.agent.clear()
                for k, vs in internals.items():
                    internals[k] = [v.detach() for v in vs]
Code example #7
    def run(self):
        local_step_count = global_step_count = self.initial_step_count
        next_save = self.init_next_save(self.initial_step_count, self.epoch_len)
        prev_step_t = time()
        ep_rewards = torch.zeros(self.nb_env)

        obs = dtensor_to_dev(self.env_mgr.reset(), self.device)
        internals = listd_to_dlist(
            [
                self.network.new_internals(self.device)
                for _ in range(self.nb_env)
            ]
        )
        start_time = time()
        while global_step_count < self.nb_step:
            actions, internals = self.agent.act(self.network, obs, internals)
            next_obs, rewards, terminals, infos = self.env_mgr.step(actions)
            next_obs = dtensor_to_dev(next_obs, self.device)

            self.agent.observe(
                obs,
                rewards.to(self.device).float(),
                terminals.to(self.device).float(),
                infos,
            )
            for i, terminal in enumerate(terminals):
                if terminal:
                    for k, v in self.network.new_internals(self.device).items():
                        internals[k][i] = v

            # Perform state updates
            local_step_count += self.nb_env
            global_step_count += self.nb_env * self.world_size
            ep_rewards += rewards.float()
            obs = next_obs

            term_rewards = []
            for i, terminal in enumerate(terminals):
                if terminal:
                    for k, v in self.network.new_internals(self.device).items():
                        internals[k][i] = v
                    term_rewards.append(ep_rewards[i].item())
                    ep_rewards[i].zero_()

            if term_rewards:
                term_reward = np.mean(term_rewards)
                delta_t = time() - start_time
                self.logger.info(
                    "RANK: {} "
                    "GLOBAL STEP: {} "
                    "REWARD: {} "
                    "GLOBAL STEP/S: {} "
                    "LOCAL STEP/S: {}".format(
                        self.global_rank,
                        global_step_count,
                        term_reward,
                        (global_step_count - self.initial_step_count) / delta_t,
                        (local_step_count - self.initial_step_count) / delta_t,
                    )
                )
                self.summary_writer.add_scalar(
                    "reward", term_reward, global_step_count
                )

            if global_step_count >= next_save:
                self.saver.save_state_dicts(
                    self.network, global_step_count, self.optimizer
                )
                next_save += self.epoch_len

            # Learn
            if self.agent.is_ready():
                loss_dict, metric_dict = self.agent.learn_step(
                    self.updater, self.network, next_obs, internals
                )
                total_loss = torch.sum(
                    torch.stack(tuple(loss for loss in loss_dict.values()))
                )

                self.agent.clear()
                for k, vs in internals.items():
                    internals[k] = [v.detach() for v in vs]

                # write summaries
                cur_step_t = time()
                if cur_step_t - prev_step_t > self.summary_freq:
                    self.write_summaries(
                        self.summary_writer,
                        global_step_count,
                        total_loss,
                        loss_dict,
                        metric_dict,
                        self.network.named_parameters(),
                    )
                    prev_step_t = cur_step_t
Code example #8
    def run(self, workers, profile=False):
        if profile:
            try:
                from pyinstrument import Profiler
            except ImportError:
                raise ImportError('You must install pyinstrument to use profiling.')
            profiler = Profiler()
            profiler.start()

        # setup queuer
        rollout_queuer = RolloutQueuerAsync(workers, self.nb_learn_batch, self.rollout_queue_size)
        rollout_queuer.start()

        # initial setup
        global_step_count = self.initial_step_count
        next_save = self.init_next_save(self.initial_step_count, self.epoch_len)
        prev_step_t = time()
        ep_rewards = torch.zeros(self.nb_env)
        start_time = time()

        # loop until total number steps
        print('{} starting training'.format(self.rank))
        while not self.done(global_step_count):
            self.exp.clear()
            # Get batch from queue
            rollouts, terminal_rewards, terminal_infos = rollout_queuer.get()

            # Iterate forward on batch
            self.exp.write_exps(rollouts)
            # keep a copy of the terminals on the CPU; indexing them there is faster
            rollout_terminals = torch.stack(self.exp['terminals']).numpy()
            self.exp.to(self.device)
            r = self.exp.read()
            internals = {k: ts[0].unbind(0) for k, ts in r.internals.items()}
            for obs, rewards, terminals in zip(
                    r.observations,
                    r.rewards,
                    rollout_terminals
            ):
                _, h_exp, internals = self.actor.act(self.network, obs,
                                                     internals)
                self.exp.write_actor(h_exp, no_env=True)

                # np.where returns a single-element tuple containing the indices
                terminal_inds = np.where(terminals)[0]
                for i in terminal_inds:
                    for k, v in self.network.new_internals(self.device).items():
                        internals[k][i] = v

            # compute loss
            loss_dict, metric_dict = self.learner.compute_loss(
                self.network, self.exp.read(), r.next_observation, internals
            )
            total_loss = torch.sum(
                torch.stack(tuple(loss for loss in loss_dict.values()))
            )

            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

            # Perform state updates
            global_step_count += self.nb_env * self.nb_learn_batch * len(r.terminals) * self.nb_learners

            # if rank 0 write summaries and save
            # and send parameters to workers async
            if self.rank == 0:
                # TODO: this could be parallelized, chunk by nb learners
                self.synchronize_worker_parameters(workers, global_step_count)

                # possible save
                if global_step_count >= next_save:
                    self.saver.save_state_dicts(
                        self.network, global_step_count, self.optimizer
                    )
                    next_save += self.epoch_len

                # write reward summaries
                if any(terminal_rewards):
                    terminal_rewards = list(filter(lambda x: x is not None, terminal_rewards))
                    self.summary_writer.add_scalar(
                        'reward', np.mean(terminal_rewards), global_step_count
                    )

                # write infos
                if any(terminal_infos):
                    terminal_infos = list(filter(lambda x: x is not None, terminal_infos))
                    float_keys = [
                        k for k, v in terminal_infos[0].items() if type(v) == float
                    ]
                    terminal_infos_dlist = listd_to_dlist(terminal_infos)
                    for k in float_keys:
                        self.summary_writer.add_scalar(
                            f'info/{k}',
                            np.mean(terminal_infos_dlist[k]),
                            global_step_count
                        )

            # write summaries
            cur_step_t = time()
            if cur_step_t - prev_step_t > self.summary_freq:
                print('Rank {} Metrics:'.format(self.rank), rollout_queuer.metrics())
                if self.rank == 0:
                    self.write_summaries(
                        self.summary_writer, global_step_count, total_loss,
                        loss_dict, metric_dict, self.network.named_parameters()
                    )
                prev_step_t = cur_step_t

        rollout_queuer.close()
        print('{} stopped training'.format(self.rank))
        if profile:
            profiler.stop()
            print(profiler.output_text(unicode=True, color=True))
Code example #9
    def run(self):
        local_step_count = global_step_count = self.initial_step_count
        ep_rewards = torch.zeros(self.nb_env)

        obs = dtensor_to_dev(self.env_mgr.reset(), self.device)
        internals = listd_to_dlist([
            self.network.new_internals(self.device) for _ in
            range(self.nb_env)
        ])
        start_time = time()
        while global_step_count < self.nb_step:
            actions, internals = self.agent.act(self.network, obs, internals)
            next_obs, rewards, terminals, infos = self.env_mgr.step(actions)
            next_obs = dtensor_to_dev(next_obs, self.device)

            self.agent.observe(
                obs,
                rewards.to(self.device).float(),
                terminals.to(self.device).float(),
                infos
            )
            for i, terminal in enumerate(terminals):
                if terminal:
                    for k, v in self.network.new_internals(self.device).items():
                        internals[k][i] = v

            # Perform state updates
            local_step_count += self.nb_env
            global_step_count += self.nb_env * self.world_size
            ep_rewards += rewards.float()
            obs = next_obs

            term_rewards = []
            for i, terminal in enumerate(terminals):
                if terminal:
                    for k, v in self.network.new_internals(self.device).items():
                        internals[k][i] = v
                    term_rewards.append(ep_rewards[i].item())
                    ep_rewards[i].zero_()

            if term_rewards:
                term_reward = np.mean(term_rewards)
                delta_t = time() - start_time
                self.logger.info(
                    'RANK: {} '
                    'GLOBAL STEP: {} '
                    'REWARD: {} '
                    'GLOBAL STEP/S: {} '
                    'LOCAL STEP/S: {}'.format(
                        self.global_rank,
                        global_step_count,
                        term_reward,
                        (global_step_count - self.initial_step_count) / delta_t,
                        (local_step_count - self.initial_step_count) / delta_t
                    )
                )

            # Learn
            if self.agent.is_ready():
                loss_dict, metric_dict = self.agent.compute_loss(
                    self.network, next_obs, internals
                )
                total_loss = torch.sum(
                    torch.stack(tuple(loss for loss in loss_dict.values()))
                )

                self.optimizer.zero_grad()
                total_loss.backward()
                dist.barrier()
                handles = []
                for param in self.network.parameters():
                    handles.append(
                        dist.all_reduce(param.grad, async_op=True))
                for handle in handles:
                    handle.wait()
                # for param in self.network.parameters():
                #     param.grad.mul_(1. / self.world_size)
                self.optimizer.step()

                self.agent.clear()
                for k, vs in internals.items():
                    internals[k] = [v.detach() for v in vs]
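
In this last example, dist.all_reduce sums each parameter's gradient across ranks; the commented-out lines show the optional division by world_size that would turn the sum into an average. A minimal synchronous sketch of that pattern (not the library's own helper):

import torch
import torch.distributed as dist

def average_gradients(network: torch.nn.Module, world_size: int) -> None:
    # all_reduce sums gradients in place across ranks; dividing by the
    # number of ranks turns the sum into a mean
    for param in network.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad.div_(world_size)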