Example 1
    def __init__(self, env_fns, engine):
        # TODO: sharing cuda tensors requires spawn or forkserver but these do not work with mpi
        # mp.set_start_method('spawn')
        self.engine = engine

        self.waiting = False
        self.closed = False
        self.nb_env = len(env_fns)

        self.remotes, self.work_remotes = zip(
            *[mp.Pipe() for _ in range(self.nb_env)])
        self.ps = [
            mp.Process(target=worker,
                       args=(work_remote, remote, CloudpickleWrapper(env_fn)))
            for (work_remote, remote,
                 env_fn) in zip(self.work_remotes, self.remotes, env_fns)
        ]
        for p in self.ps:
            p.daemon = True  # if the main process crashes, we should not cause things to hang
            p.start()
        for remote in self.work_remotes:
            remote.close()

        self.remotes[0].send(('get_spaces', None))
        self._observation_space, self._action_space = self.remotes[0].recv()
        self.remotes[0].send(('get_processors', None))
        self._cpu_preprocessor, self._gpu_preprocessor = self.remotes[0].recv()

        shared_memories = []
        for remote in self.remotes:
            remote.send(('get_shared_memory', None))
            shared_memories.append(remote.recv())
        self.shared_memories = listd_to_dlist(shared_memories)
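Example 1 only shows the parent side of the pipe protocol. Each `mp.Process` runs a `worker` loop that answers the `get_spaces`, `get_processors`, and `get_shared_memory` commands sent above (and later `reset`/`step`). That function is not shown in these examples, so the following is only a rough sketch under those assumptions; in particular the `.x()` accessor on `CloudpickleWrapper` follows the OpenAI Baselines convention and is an assumption.

def worker(remote, parent_remote, env_fn_wrapper):
    parent_remote.close()  # the child only needs its own end of the pipe
    env = env_fn_wrapper.x()  # assumed accessor: CloudpickleWrapper wraps the env factory
    while True:
        cmd, data = remote.recv()
        if cmd == 'get_spaces':
            remote.send((env.observation_space, env.action_space))
        elif cmd == 'get_processors':
            remote.send((env.cpu_preprocessor, env.gpu_preprocessor))
        elif cmd == 'get_shared_memory':
            remote.send(env.shared_memory)  # assumed: dict of torch tensors in shared memory
        elif cmd == 'reset':
            remote.send(env.reset())
        elif cmd == 'step':
            remote.send(env.step(data))
        elif cmd == 'close':
            remote.close()
            break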
Example 2
 def reset(self):
     for remote in self.remotes:
         remote.send(('reset', None))
     obs = listd_to_dlist([remote.recv() for remote in self.remotes])
     shared_mems = {k: torch.stack(v) for k, v in self.shared_memories.items()}
     obs = {**obs, **shared_mems}
     return obs
Example 3
 def reset(self):
     for socket in self._zmq_sockets:
         socket.send('reset'.encode())
     obs = listd_to_dlist([json.loads(remote.recv().decode()) for remote in self._zmq_sockets])
     shared_mems = {k: torch.stack(v) for k, v in self.shared_memories.items()}
     obs = {**obs, **shared_mems}
     return obs
Example 4
 def step_wait(self):
     results = [remote.recv() for remote in self.remotes]
     self.waiting = False
     obs, rews, dones, infos = zip(*results)
     obs = listd_to_dlist(obs)
     shared_mems = {k: torch.stack(v) for k, v in self.shared_memories.items()}
     obs = {**obs, **shared_mems}
     return obs, rews, dones, infos
Example 5
    def __init__(self, args, log_id_dir, initial_step_count, rank):
        seed = args.seed \
            if rank == 0 \
            else args.seed + args.nb_env * rank
        print('Worker {} using seed {}'.format(rank, seed))

        # load saved registry classes
        REGISTRY.load_extern_classes(log_id_dir)

        # ENV
        engine = REGISTRY.lookup_engine(args.env)
        env_cls = REGISTRY.lookup_env(args.env)
        mgr_cls = REGISTRY.lookup_manager(args.manager)
        env_mgr = mgr_cls.from_args(args, engine, env_cls, seed=seed)

        # NETWORK
        torch.manual_seed(args.seed)
        device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
        output_space = REGISTRY.lookup_output_space(args.actor_worker,
                                                    env_mgr.action_space)
        if args.custom_network:
            net_cls = REGISTRY.lookup_network(args.custom_network)
        else:
            net_cls = ModularNetwork
        net = net_cls.from_args(args, env_mgr.observation_space, output_space,
                                env_mgr.gpu_preprocessor, REGISTRY)
        actor_cls = REGISTRY.lookup_actor(args.actor_worker)
        actor = actor_cls.from_args(args, env_mgr.action_space)
        builder = actor_cls.exp_spec_builder(env_mgr.observation_space,
                                             env_mgr.action_space,
                                             net.internal_space(),
                                             env_mgr.nb_env)
        exp = REGISTRY.lookup_exp(args.exp).from_args(args, builder)

        self.actor = actor
        self.exp = exp.to(device)
        self.nb_step = args.nb_step
        self.env_mgr = env_mgr
        self.nb_env = args.nb_env
        self.network = net.to(device)
        self.device = device
        self.initial_step_count = initial_step_count

        # TODO: this should be set to eval after some number of training steps
        self.network.train()

        # SETUP state variables for run
        self.step_count = self.initial_step_count
        self.global_step_count = self.initial_step_count
        self.ep_rewards = torch.zeros(self.nb_env)
        self.rank = rank

        self.obs = dtensor_to_dev(self.env_mgr.reset(), self.device)
        self.internals = listd_to_dlist([
            self.network.new_internals(self.device) for _ in range(self.nb_env)
        ])
        self.start_time = time()
        self._weights_synced = False
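`dtensor_to_dev` is used here (and again in Examples 14 and 15) but not defined in these snippets; in context it simply moves a dict of observation tensors onto the target device. A minimal sketch, assuming a flat dict whose tensor values should move and whose other values pass through:

import torch

def dtensor_to_dev(obs_dict, device):
    # Move every tensor value of an observation dict to the given torch device;
    # non-tensor entries are passed through unchanged.
    return {
        k: v.to(device) if isinstance(v, torch.Tensor) else v
        for k, v in obs_dict.items()
    }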
Example 6
    def __init__(self, network, device, reward_normalizer, gpu_preprocessor, nb_env, nb_rollout, discount,
                 minimum_importance_value=1.0, minimum_importance_policy=1.0, entropy_weight=0.01):
        self.discount = discount
        self.gpu_preprocessor = gpu_preprocessor
        self.minimum_importance_value = minimum_importance_value
        self.minimum_importance_policy = minimum_importance_policy
        self.entropy_weight = entropy_weight

        self._network = network.to(device)
        self._exp_cache = RolloutCache(nb_rollout, device, reward_normalizer, ['log_prob_of_action', 'sampled_action'])
        self._internals = listd_to_dlist([self.network.new_internals(device) for _ in range(nb_env)])
        self._device = device
        self.network.train()
Example 7
    def step_wait(self):
        results = [remote.recv() for remote in self._zmq_sockets]
        self.waiting = False

        # check for errors and parse
        self._check_for_errors(results)
        results = [json.loads(res.decode()) for res in results]
        obs, rews, dones, infos = zip(*results)

        obs = listd_to_dlist(obs)
        shared_mems = {k: torch.stack(v) for k, v in self.shared_memories.items()}
        obs = {**obs, **shared_mems}
        return obs, torch.tensor(rews), torch.tensor(dones), infos
Example 8
 def reset(self):
     obs = []
     for e in range(self.nb_env):
         ob = self.envs[e].reset()
         obs.append(ob)
     obs = listd_to_dlist(obs)
     new_obs = {}
     for k, v in dummy_handle_ob(obs).items():
         if self._is_tensor_key(k):
             new_obs[k] = torch.stack(v)
         else:
             new_obs[k] = v
     self.buf_obs = new_obs
     return self.buf_obs
Example 9
    def __init__(self, env_fns, engine):
        super(SubProcEnvManager, self).__init__(env_fns, engine)
        self.waiting = False
        self.closed = False
        self.processes = []

        self._zmq_context = zmq.Context()
        self._zmq_ports = []
        self._zmq_sockets = []

        # make a temporary env to query spaces and preprocessors
        dummy = env_fns[0]()
        self._observation_space = dummy.observation_space
        self._action_space = dummy.action_space
        self._cpu_preprocessor = dummy.cpu_preprocessor
        self._gpu_preprocessor = dummy.gpu_preprocessor
        dummy.close()

        # Allows msgpack to work with NumPy
        m.patch()

        # iterate envs to get torch shared memory through pipe then close it
        shared_memories = []

        for w_ind in range(self.nb_env):
            pipe, w_pipe = mp.Pipe()
            socket, port = zmq_robust_bind_socket(self._zmq_context)

            process = mp.Process(
                target=worker,
                args=(w_pipe, pipe, port, CloudpickleWrapper(env_fns[w_ind])),
            )
            process.daemon = True
            process.start()
            self.processes.append(process)

            self._zmq_sockets.append(socket)

            pipe.send(("get_shared_memory", None))
            shared_memories.append(pipe.recv())

            # switch to zmq socket and close pipes
            pipe.send(("switch_zmq", None))
            pipe.close()
            w_pipe.close()

        self.shared_memories = listd_to_dlist(shared_memories)
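`zmq_robust_bind_socket` is not shown in these examples. A plausible minimal sketch, assuming a PAIR socket bound to a random free TCP port via pyzmq's `bind_to_random_port`; the socket type and port range are assumptions for illustration:

import zmq

def zmq_robust_bind_socket(zmq_context):
    # Bind a PAIR socket to a random free TCP port, retrying on collisions.
    socket = zmq_context.socket(zmq.PAIR)
    port = socket.bind_to_random_port(
        'tcp://*', min_port=49152, max_port=65536, max_tries=100)
    return socket, port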
Example 10
    def step_wait(self):
        obs = []
        for e in range(self.nb_env):
            ob, self.buf_rews[e], self.buf_dones[e], self.buf_infos[e] = self.envs[e].step(self.actions[e])
            if self.buf_dones[e]:
                ob = self.envs[e].reset()
            obs.append(ob)
        obs = listd_to_dlist(obs)
        new_obs = {}
        for k, v in dummy_handle_ob(obs).items():
            if self._is_tensor_key(k):
                new_obs[k] = torch.stack(v)
            else:
                new_obs[k] = v
        self.buf_obs = new_obs

        return self.buf_obs, self.buf_rews, self.buf_dones, self.buf_infos
Example 11
    def reset(self):
        """Tell all subprocess environments to reset to initial state.

        Returns
        -------
        obs : dict[str, torch.Tensor]
            Observation
        """
        for socket in self._zmq_sockets:
            socket.send(msgpack.dumps("reset"))
        obs = listd_to_dlist(
            [msgpack.loads(remote.recv()) for remote in self._zmq_sockets])
        shared_mems = {
            k: torch.stack(v)
            for k, v in self.shared_memories.items()
        }
        obs = {**obs, **shared_mems}
        return obs
Example 12
 def _to_cpu(self, var):
     # TODO: this is a hack, should instead register a custom serializer for torch tensors to go
     # to CPU
     if isinstance(var, list):
         # list of dict -> dict of lists
         # observations/actions/internals
         if isinstance(var[0], dict):
             # if empty dict it doesn't matter
             if len(var[0]) == 0:
                 return {}
             first_v = next(iter(var[0].values()))
             # observations/actions
             if isinstance(first_v, torch.Tensor):
                 return {
                     k: torch.stack(v).cpu()
                     for k, v in listd_to_dlist(var).items()
                 }
             # internals
             elif isinstance(first_v, list):
                 # TODO: there's gotta be a better way to do this
                 assert len(var) == 1
                 return {
                     k: torch.stack(v).cpu().unsqueeze(0)
                     for k, v in var[0].items()
                 }
         # other actor stuff
         elif isinstance(var[0], torch.Tensor):
             return torch.stack(var).cpu()
         else:
             raise NotImplementedError(
                 "Expected rollout item to be a Tensor or dict(Tensors) got {}"
                 .format(type(var[0])))
     elif isinstance(var, dict):
         # next obs
         first_v = next(iter(var.values()))
         if isinstance(first_v, torch.Tensor):
             return {k: v.cpu() for k, v in var.items()}
         else:
             raise NotImplementedError(
                 "Expected rollout dict item to be a tensor got {}".format(
                     type(first_v)))
     else:
         raise NotImplementedError(
             "Expected rollout object to be a list or dict got {}".format(
                 type(var)))
Example 13
    def __init__(self,
                 network,
                 device,
                 reward_normalizer,
                 gpu_preprocessor,
                 nb_env,
                 nb_rollout,
                 discount,
                 gae,
                 tau,
                 normalize_advantage,
                 entropy_weight=0.01):
        self.discount, self.gae, self.tau = discount, gae, tau
        self.normalize_advantage = normalize_advantage
        self.entropy_weight = entropy_weight
        self.gpu_preprocessor = gpu_preprocessor

        self._network = network.to(device)
        self._exp_cache = RolloutCache(nb_rollout, device, reward_normalizer,
                                       ['values', 'log_probs', 'entropies'])
        self._internals = listd_to_dlist(
            [self.network.new_internals(device) for _ in range(nb_env)])
        self._device = device
        self.network.train()
Example 14
    def run(self):
        step_count = self.initial_step_count
        next_save = self.init_next_save(self.initial_step_count,
                                        self.epoch_len)
        prev_step_t = time()
        ep_rewards = torch.zeros(self.nb_env)

        obs = dtensor_to_dev(self.env_mgr.reset(), self.device)
        internals = listd_to_dlist([
            self.network.new_internals(self.device) for _ in range(self.nb_env)
        ])
        start_time = time()
        while step_count < self.nb_step:
            actions, internals = self.agent.act(self.network, obs, internals)
            next_obs, rewards, terminals, infos = self.env_mgr.step(actions)
            next_obs = dtensor_to_dev(next_obs, self.device)

            self.agent.observe(
                obs,
                rewards.to(self.device).float(),
                terminals.to(self.device).float(),
                infos,
            )

            # Perform state updates
            step_count += self.nb_env
            ep_rewards += rewards.float()
            obs = next_obs

            term_rewards, term_infos = [], []
            for i, terminal in enumerate(terminals):
                if terminal:
                    for k, v in self.network.new_internals(
                            self.device).items():
                        internals[k][i] = v
                    term_rewards.append(ep_rewards[i].item())
                    if infos[i]:
                        term_infos.append(infos[i])
                    ep_rewards[i].zero_()

            if term_rewards:
                term_reward = np.mean(term_rewards)
                delta_t = time() - start_time
                self.logger.info("STEP: {} REWARD: {} STEP/S: {}".format(
                    step_count,
                    term_reward,
                    (step_count - self.initial_step_count) / delta_t,
                ))
                self.summary_writer.add_scalar("reward", term_reward,
                                               step_count)
                if term_infos:
                    float_keys = [
                        k for k, v in term_infos[0].items() if type(v) == float
                    ]
                    term_infos_dlist = listd_to_dlist(term_infos)
                    for k in float_keys:
                        self.summary_writer.add_scalar(
                            f"info/{k}",
                            np.mean(term_infos_dlist[k]),
                            step_count,
                        )

            if step_count >= next_save:
                self.saver.save_state_dicts(self.network, step_count,
                                            self.optimizer)
                next_save += self.epoch_len

            # Learn
            if self.agent.is_ready():
                loss_dict, metric_dict = self.agent.learn_step(
                    self.updater,
                    self.network,
                    next_obs,
                    internals,
                )
                total_loss = sum(loss_dict.values())

                epoch = step_count / self.nb_env
                self.scheduler.step(epoch)

                self.agent.clear()
                for k, vs in internals.items():
                    internals[k] = [v.detach() for v in vs]

                # write summaries
                cur_step_t = time()
                if cur_step_t - prev_step_t > self.summary_freq:
                    self.write_summaries(
                        self.summary_writer,
                        step_count,
                        total_loss,
                        loss_dict,
                        metric_dict,
                        self.network.named_parameters(),
                    )
                    prev_step_t = cur_step_t
Example 15
    def run(self):
        nb_env = self.env_mgr.nb_env
        best_epoch_id = None
        overall_mean = -float("inf")
        for epoch_id in self.epoch_ids:
            best_mean = -float("inf")
            best_std = None
            selected_model = None
            reward_buf = torch.zeros(nb_env)
            for net_path in self.log_dir_helper.network_paths_at_epoch(
                epoch_id
            ):
                self.network.load_state_dict(
                    torch.load(
                        net_path, map_location=lambda storage, loc: storage
                    )
                )
                self.network.eval()

                internals = listd_to_dlist(
                    [
                        self.network.new_internals(self.device)
                        for _ in range(nb_env)
                    ]
                )
                episode_completes = [False for _ in range(nb_env)]
                next_obs = dtensor_to_dev(self.env_mgr.reset(), self.device)

                while not all(episode_completes):
                    obs = next_obs
                    with torch.no_grad():
                        actions, _, internals = self.actor.act(
                            self.network, obs, internals
                        )
                    next_obs, rewards, terminals, infos = self.env_mgr.step(
                        actions
                    )
                    next_obs = dtensor_to_dev(next_obs, self.device)

                    for i in range(self.env_mgr.nb_env):
                        if episode_completes[i]:
                            continue
                        elif terminals[i]:
                            reward_buf[i] += rewards[i]
                            episode_completes[i] = True
                        else:
                            reward_buf[i] += rewards[i]

                mean = reward_buf.mean().item()
                std = reward_buf.std().item()

                if mean >= best_mean:
                    best_mean = mean
                    best_std = std
                    selected_model = os.path.split(net_path)[-1]

            self.logger.info(
                f"EPOCH_ID: {epoch_id} "
                f"MEAN_REWARD: {best_mean} "
                f"STD_DEV: {best_std} "
                f"SELECTED_MODEL: {selected_model}"
            )
            with open(self.log_dir_helper.eval_path(), "a") as eval_f:
                eval_f.write(
                    f"{epoch_id},"
                    f"{best_mean},"
                    f"{best_std},"
                    f"{selected_model}\n"
                )

            if best_mean >= overall_mean:
                best_epoch_id = epoch_id
                overall_mean = best_mean
        self.logger.info(
            f"*** EPOCH_ID: {best_epoch_id} MEAN_REWARD: {overall_mean} ***"
        )
Example 16
 def test_listd_to_dlist(self):
     assert listd_to_dlist([{'a': 1}]) == {'a': [1]}
Example 17
 def test_listd_to_dlist(self):
     assert listd_to_dlist([{"a": 1}]) == {"a": [1]}
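The two tests pin down the behaviour of `listd_to_dlist`, which every example above relies on: it turns a list of dicts into a dict of lists. A minimal implementation consistent with those tests (the real one may differ in details such as key handling):

def listd_to_dlist(list_of_dicts):
    # [{'a': 1}, {'a': 2}] -> {'a': [1, 2]}
    dict_of_lists = {}
    for d in list_of_dicts:
        for k, v in d.items():
            dict_of_lists.setdefault(k, []).append(v)
    return dict_of_lists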