Example #1: worker initialization (__init__) — derives a per-rank seed, builds the environment manager, network, actor, and experience cache, then resets environment state.
    def __init__(self, args, log_id_dir, initial_step_count, rank):
        seed = args.seed \
            if rank == 0 \
            else args.seed + args.nb_env * rank
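        # e.g. nb_env=32: rank 0 -> seed, rank 1 -> seed + 32, rank 2 -> seed + 64,
        # giving each worker its own block of environment seeds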
        print('Worker {} using seed {}'.format(rank, seed))

        # load saved registry classes
        REGISTRY.load_extern_classes(log_id_dir)

        # ENV
        engine = REGISTRY.lookup_engine(args.env)
        env_cls = REGISTRY.lookup_env(args.env)
        mgr_cls = REGISTRY.lookup_manager(args.manager)
        env_mgr = mgr_cls.from_args(args, engine, env_cls, seed=seed)

        # NETWORK
        # note: network init uses the base seed (args.seed), not the per-rank seed,
        # so all workers start from identical weights
        torch.manual_seed(args.seed)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        output_space = REGISTRY.lookup_output_space(args.actor_worker,
                                                    env_mgr.action_space)
        if args.custom_network:
            net_cls = REGISTRY.lookup_network(args.custom_network)
        else:
            net_cls = ModularNetwork
        net = net_cls.from_args(args, env_mgr.observation_space, output_space,
                                env_mgr.gpu_preprocessor, REGISTRY)
        actor_cls = REGISTRY.lookup_actor(args.actor_worker)
        actor = actor_cls.from_args(args, env_mgr.action_space)
        builder = actor_cls.exp_spec_builder(env_mgr.observation_space,
                                             env_mgr.action_space,
                                             net.internal_space(),
                                             env_mgr.nb_env)
        exp = REGISTRY.lookup_exp(args.exp).from_args(args, builder)

        self.actor = actor
        self.exp = exp.to(device)
        self.nb_step = args.nb_step
        self.env_mgr = env_mgr
        self.nb_env = args.nb_env
        self.network = net.to(device)
        self.device = device
        self.initial_step_count = initial_step_count

        # TODO: this should be set to eval after some number of training steps
        self.network.train()

        # SETUP state variables for run
        self.step_count = self.initial_step_count
        self.global_step_count = self.initial_step_count
        self.ep_rewards = torch.zeros(self.nb_env)
        self.rank = rank

        self.obs = dtensor_to_dev(self.env_mgr.reset(), self.device)
        self.internals = listd_to_dlist([
            self.network.new_internals(self.device) for _ in range(self.nb_env)
        ])
        self.start_time = time()
        self._weights_synced = False
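
The constructor leans on two small helpers, listd_to_dlist and dtensor_to_dev, whose implementations are not shown in these examples. A minimal sketch of the behavior assumed throughout (turning a list of per-env dicts into a dict of lists, and moving a dict of tensors onto a device) might look like the following; treat it as an illustration, not the library's actual code:

def listd_to_dlist(list_of_dicts):
    # [{"h": t0}, {"h": t1}] -> {"h": [t0, t1]}
    dlist = {}
    for d in list_of_dicts:
        for k, v in d.items():
            dlist.setdefault(k, []).append(v)
    return dlist

def dtensor_to_dev(tensor_dict, device):
    # move every tensor in a {name: tensor} dict onto the target device
    return {k: v.to(device) for k, v in tensor_dict.items()}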
Example #2: worker rollout (run) — generates one rollout, tracks terminal rewards and infos, and returns the packed experience.
    def run(self):
        if not self._weights_synced:
            raise RuntimeError("Must set weights before calling run")

        self.exp.clear()
        all_terminal_rewards = []
        all_terminal_infos = {}

        # loop to generate a rollout
        while not self.exp.is_ready():
            with torch.no_grad():
                actions, exp, self.internals = self.actor.act(
                    self.network, self.obs, self.internals)

            self.exp.write_actor(exp)

            next_obs, rewards, terminals, infos = self.env_mgr.step(actions)
            next_obs = dtensor_to_dev(next_obs, self.device)
            self.exp.write_env(self.obs, rewards.float(), terminals.float(),
                               infos)

            # Perform state updates
            self.step_count += self.nb_env
            self.ep_rewards += rewards.float()
            self.obs = next_obs

            term_rewards = []
            for i, terminal in enumerate(terminals):
                if terminal:
                    for k, v in self.network.new_internals(
                            self.device).items():
                        self.internals[k][i] = v
                    rew = self.ep_rewards[i].item()
                    term_rewards.append(rew)
                    self.ep_rewards[i].zero_()

                    for k, v in infos[i].items():
                        if k not in all_terminal_infos:
                            all_terminal_infos[k] = []
                        all_terminal_infos[k].append(v)

            # avg rewards
            if term_rewards:
                term_reward = np.mean(term_rewards)
                all_terminal_rewards.append(term_reward)

                delta_t = time() - self.start_time
                print("RANK: {} "
                      "LOCAL STEP: {} "
                      "REWARD: {} "
                      "LOCAL STEP/S: {:.2f}".format(
                          self.rank,
                          self.step_count,
                          term_reward,
                          (self.step_count - self.initial_step_count) /
                          delta_t,
                      ))

        # rollout is full; return it
        self.exp.write_next_obs(self.obs)
        # TODO: compression?
        if len(all_terminal_rewards) > 0:
            return {
                "rollout": self._ray_pack(self.exp),
                "terminal_rewards": np.mean(all_terminal_rewards),
                "terminal_infos":
                {k: np.mean(v)
                 for k, v in all_terminal_infos.items()},
            }
        else:
            return {
                "rollout": self._ray_pack(self.exp),
                "terminal_rewards": None,
                "terminal_infos": None,
            }
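
A host process is expected to push fresh weights to the worker before each call to run() and then aggregate the returned dict. The sketch below is hypothetical usage only: set_weights, learner_state, and num_rollouts are placeholder names, and in a Ray deployment these calls would go through .remote()/ray.get rather than direct method calls:

# hypothetical driver loop around the worker above
rollouts = []
for _ in range(num_rollouts):
    worker.set_weights(learner_state)  # assumed to flip worker._weights_synced
    out = worker.run()
    rollouts.append(out["rollout"])
    if out["terminal_rewards"] is not None:
        print("mean terminal reward:", out["terminal_rewards"])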
Example #3: evaluation loop (run) — loads each saved network per epoch, runs every environment to episode completion, and logs the best checkpoint per epoch and overall.
    def run(self):
        nb_env = self.env_mgr.nb_env
        best_epoch_id = None
        overall_mean = -float("inf")
        for epoch_id in self.epoch_ids:
            best_mean = -float("inf")
            best_std = None
            selected_model = None
            for net_path in self.log_dir_helper.network_paths_at_epoch(
                epoch_id
            ):
                self.network.load_state_dict(
                    torch.load(
                        net_path, map_location=lambda storage, loc: storage
                    )
                )
                self.network.eval()

                # reset the per-env episode reward buffer for each checkpoint,
                # otherwise rewards accumulate across models in the same epoch
                reward_buf = torch.zeros(nb_env)

                internals = listd_to_dlist(
                    [
                        self.network.new_internals(self.device)
                        for _ in range(nb_env)
                    ]
                )
                episode_completes = [False for _ in range(nb_env)]
                next_obs = dtensor_to_dev(self.env_mgr.reset(), self.device)

                while not all(episode_completes):
                    obs = next_obs
                    with torch.no_grad():
                        actions, _, internals = self.actor.act(
                            self.network, obs, internals
                        )
                    next_obs, rewards, terminals, infos = self.env_mgr.step(
                        actions
                    )
                    next_obs = dtensor_to_dev(next_obs, self.device)

                    for i in range(self.env_mgr.nb_env):
                        if episode_completes[i]:
                            continue
                        elif terminals[i]:
                            reward_buf[i] += rewards[i]
                            episode_completes[i] = True
                        else:
                            reward_buf[i] += rewards[i]

                mean = reward_buf.mean().item()
                std = reward_buf.std().item()

                if mean >= best_mean:
                    best_mean = mean
                    best_std = std
                    selected_model = os.path.split(net_path)[-1]

            self.logger.info(
                f"EPOCH_ID: {epoch_id} "
                f"MEAN_REWARD: {best_mean} "
                f"STD_DEV: {best_std} "
                f"SELECTED_MODEL: {selected_model}"
            )
            with open(self.log_dir_helper.eval_path(), "a") as eval_f:
                eval_f.write(
                    f"{epoch_id},"
                    f"{best_mean},"
                    f"{best_std},"
                    f"{selected_model}\n"
                )

            if best_mean >= overall_mean:
                best_epoch_id = epoch_id
                overall_mean = best_mean
        self.logger.info(
            f"*** EPOCH_ID: {best_epoch_id} MEAN_REWARD: {overall_mean} ***"
        )
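
Each epoch appends one "epoch_id,mean,std,model" row to the file at log_dir_helper.eval_path(), so results can be re-ranked offline. A small sketch of reading that CSV back, assuming the exact format written above and no header row:

import csv

def best_eval_row(eval_path):
    # returns (mean, epoch_id, std, model) for the highest-scoring epoch
    best = None
    with open(eval_path) as f:
        for epoch_id, mean, std, model in csv.reader(f):
            row = (float(mean), epoch_id, float(std), model)
            if best is None or row[0] > best[0]:
                best = row
    return best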
Example #4: local training loop (run) — steps the environments, logs rewards and summaries, saves checkpoints, and performs learner updates when the agent is ready.
    def run(self):
        step_count = self.initial_step_count
        next_save = self.init_next_save(self.initial_step_count,
                                        self.epoch_len)
        prev_step_t = time()
        ep_rewards = torch.zeros(self.nb_env)

        obs = dtensor_to_dev(self.env_mgr.reset(), self.device)
        internals = listd_to_dlist([
            self.network.new_internals(self.device) for _ in range(self.nb_env)
        ])
        start_time = time()
        while step_count < self.nb_step:
            actions, internals = self.agent.act(self.network, obs, internals)
            next_obs, rewards, terminals, infos = self.env_mgr.step(actions)
            next_obs = dtensor_to_dev(next_obs, self.device)

            self.agent.observe(
                obs,
                rewards.to(self.device).float(),
                terminals.to(self.device).float(),
                infos,
            )

            # Perform state updates
            step_count += self.nb_env
            ep_rewards += rewards.float()
            obs = next_obs

            term_rewards, term_infos = [], []
            for i, terminal in enumerate(terminals):
                if terminal:
                    for k, v in self.network.new_internals(
                            self.device).items():
                        internals[k][i] = v
                    term_rewards.append(ep_rewards[i].item())
                    if infos[i]:
                        term_infos.append(infos[i])
                    ep_rewards[i].zero_()

            if term_rewards:
                term_reward = np.mean(term_rewards)
                delta_t = time() - start_time
                self.logger.info("STEP: {} REWARD: {} STEP/S: {}".format(
                    step_count,
                    term_reward,
                    (step_count - self.initial_step_count) / delta_t,
                ))
                self.summary_writer.add_scalar("reward", term_reward,
                                               step_count)
                if term_infos:
                    float_keys = [
                        k for k, v in term_infos[0].items()
                        if isinstance(v, float)
                    ]
                    term_infos_dlist = listd_to_dlist(term_infos)
                    for k in float_keys:
                        self.summary_writer.add_scalar(
                            f"info/{k}",
                            np.mean(term_infos_dlist[k]),
                            step_count,
                        )

            if step_count >= next_save:
                self.saver.save_state_dicts(self.network, step_count,
                                            self.optimizer)
                next_save += self.epoch_len

            # Learn
            if self.agent.is_ready():
                loss_dict, metric_dict = self.agent.learn_step(
                    self.updater,
                    self.network,
                    next_obs,
                    internals,
                )
                total_loss = sum(loss_dict.values())

                epoch = step_count / self.nb_env
                self.scheduler.step(epoch)

                self.agent.clear()
                for k, vs in internals.items():
                    internals[k] = [v.detach() for v in vs]

                # write summaries
                cur_step_t = time()
                if cur_step_t - prev_step_t > self.summary_freq:
                    self.write_summaries(
                        self.summary_writer,
                        step_count,
                        total_loss,
                        loss_dict,
                        metric_dict,
                        self.network.named_parameters(),
                    )
                    prev_step_t = cur_step_t
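
init_next_save is not shown in these examples; the loop only requires that it return the first step count at which a checkpoint should be written, after which next_save advances by epoch_len. One plausible version, stated purely as an assumption about that contract:

def init_next_save(initial_step_count, epoch_len):
    # first save boundary at or after the resume point, in multiples of epoch_len
    if initial_step_count == 0:
        return epoch_len
    return (initial_step_count // epoch_len + 1) * epoch_len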