Example #1
    def __init__(
        self,
        source_ids: List[str],
        all_preprocessors: List[Union[Preprocessor, Builder[Preprocessor]]],
        all_sensors: List[Sensor],
    ) -> None:
        """Initializer.

        # Parameters

        source_ids : The sensors and preprocessors that will be included in the set.
        all_preprocessors : The entire list of preprocessors to be executed.
        all_sensors : The entire list of sensors.
        """

        self.graph = PreprocessorGraph(all_preprocessors)

        self.source_ids = source_ids
        assert len(set(self.source_ids)) == len(
            self.source_ids
        ), "No duplicate uuids allowed in source_ids"

        sensor_spaces = SensorSuite(all_sensors).observation_spaces
        preprocessor_spaces = self.graph.observation_spaces
        spaces: OrderedDict[str, gym.Space] = OrderedDict()
        for uuid in self.source_ids:
            assert (
                uuid in sensor_spaces.spaces
                or uuid in preprocessor_spaces.spaces
            ), "uuid {} missing from sensor suite and preprocessor graph".format(
                uuid)
            if uuid in sensor_spaces.spaces:
                spaces[uuid] = sensor_spaces[uuid]
            else:
                spaces[uuid] = preprocessor_spaces[uuid]
        self.observation_spaces = SpaceDict(spaces=spaces)
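
A minimal usage sketch for this initializer (the enclosing class name ObservationSet, the sensor/preprocessor classes, and the uuids below are all assumptions for illustration):

    # Hypothetical Sensor/Preprocessor subclasses from the surrounding framework;
    # the constructor arguments here are made up for the sketch.
    rgb_sensor = MyRGBSensor(uuid="rgb")
    resnet = MyResNetPreprocessor(input_uuids=["rgb"], output_uuid="rgb_resnet")

    obs_set = ObservationSet(
        source_ids=["rgb_resnet"],       # expose only the preprocessed output
        all_preprocessors=[resnet],
        all_sensors=[rgb_sensor],
    )
    assert "rgb_resnet" in obs_set.observation_spaces.spaces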
Example #2
    def __init__(
            self,
            source_observation_spaces: SpaceDict,
            preprocessors: Sequence[Union[Preprocessor,
                                          Builder[Preprocessor]]],
            additional_output_uuids: Sequence[str] = tuple(),
    ) -> None:
        """Initializer.

        # Parameters

        source_observation_spaces : The observation spaces of all sensors before preprocessing.
            This generally should be the output of `SensorSuite.observation_spaces`.
        preprocessors : The preprocessors that will be included in the graph.
        additional_output_uuids: As described in the documentation for this class, the observations
            returned when calling `get_observations` only include, by default, those observations
            that are not processed by any preprocessor. If you'd like to include observations that
            would otherwise not be included, the uuids of these sensors should be included as
            a sequence of strings here.
        """
        self.device: torch.device = torch.device("cpu")

        obs_spaces: Dict[str, gym.Space] = {
            k: source_observation_spaces[k]
            for k in source_observation_spaces
        }

        self.preprocessors: Dict[str, Preprocessor] = {}
        for preprocessor in preprocessors:
            if isinstance(preprocessor, Builder):
                preprocessor = preprocessor()

            assert (
                preprocessor.uuid not in self.preprocessors
            ), "'{}' is a duplicate preprocessor uuid".format(preprocessor.uuid)

            self.preprocessors[preprocessor.uuid] = preprocessor
            obs_spaces[preprocessor.uuid] = preprocessor.observation_space

        g = nx.DiGraph()
        for k in obs_spaces:
            g.add_node(k)
        for k in self.preprocessors:
            for j in self.preprocessors[k].input_uuids:
                g.add_edge(j, k)

        assert nx.is_directed_acyclic_graph(
            g
        ), "preprocessors do not form a directed acyclic graph"

        self.observation_spaces = SpaceDict(
            spaces={
                uuid: obs_spaces[uuid]
                for uuid in obs_spaces
                if uuid in additional_output_uuids or g.out_degree(uuid) == 0
            })

        # ensure dependencies are precomputed
        self.compute_order = list(nx.dfs_postorder_nodes(g))
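
The dependency bookkeeping above is a small networkx pattern; this standalone sketch (plain networkx, no framework classes) mirrors it. Note that with edges running input -> consumer, the postorder lists downstream outputs before their inputs:

    import networkx as nx

    g = nx.DiGraph()
    g.add_edge("rgb", "rgb_resnet")         # preprocessor "rgb_resnet" consumes "rgb"
    g.add_edge("rgb_resnet", "rgb_pooled")  # a second-stage preprocessor
    assert nx.is_directed_acyclic_graph(g)
    print(list(nx.dfs_postorder_nodes(g)))  # ['rgb_pooled', 'rgb_resnet', 'rgb']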
Example #3
    def __init__(self, sensors: Iterable[Sensor]) -> None:
        """Constructor

        :param sensors: list containing sensors for the environment, uuid of
            each sensor must be unique.
        """
        self.sensors = OrderedDict()
        spaces: OrderedDict[str, Space] = OrderedDict()
        for sensor in sensors:
            assert (
                sensor.uuid not in self.sensors
            ), "'{}' is a duplicate sensor uuid".format(sensor.uuid)
            self.sensors[sensor.uuid] = sensor
            spaces[sensor.uuid] = sensor.observation_space
        self.observation_spaces = SpaceDict(spaces=spaces)
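
A usage sketch, assuming two concrete Sensor subclasses with distinct uuids are available (the instance names are hypothetical):

    suite = SensorSuite([rgb_sensor, depth_sensor])
    print(suite.observation_spaces)  # SpaceDict keyed by sensor uuid
    # Two sensors sharing a uuid would trip the assert above:
    # SensorSuite([rgb_sensor, rgb_sensor])  -> AssertionError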
Example #4
    def __init__(self, sensors: Sequence[Sensor]) -> None:
        """Initializer.

        # Parameters

        sensors : The sensors that will be included in the suite.
        """
        self.sensors = OrderedDict()
        spaces: OrderedDict[str, gym.Space] = OrderedDict()
        for sensor in sensors:
            assert (
                sensor.uuid not in self.sensors
            ), "'{}' is a duplicate sensor uuid".format(sensor.uuid)
            self.sensors[sensor.uuid] = sensor
            spaces[sensor.uuid] = sensor.observation_space
        self.observation_spaces = SpaceDict(spaces=spaces)
Example #5
    def __init__(
        self,
        preprocessors: List[Union[Preprocessor, Builder[Preprocessor]]],
    ) -> None:
        """Initializer.

        # Parameters

        preprocessors : The preprocessors that will be included in the graph.
        """
        self.preprocessors: Dict[str, Preprocessor] = OrderedDict()
        spaces: OrderedDict[str, gym.Space] = OrderedDict()
        for preprocessor in preprocessors:
            if isinstance(preprocessor, Builder):
                preprocessor = preprocessor()

            assert (
                preprocessor.uuid not in self.preprocessors
            ), "'{}' is a duplicate preprocessor uuid".format(preprocessor.uuid)
            self.preprocessors[preprocessor.uuid] = preprocessor
            spaces[preprocessor.uuid] = preprocessor.observation_space
        self.observation_spaces = SpaceDict(spaces=spaces)

        g = nx.DiGraph()
        for k in self.preprocessors:
            g.add_node(k)
        for k in self.preprocessors:
            for j in self.preprocessors[k].input_uuids:
                if j not in g:
                    g.add_node(j)
                g.add_edge(k, j)
        assert nx.is_directed_acyclic_graph(
            g
        ), "preprocessors do not form a directed acyclic graph"

        # ensure dependencies are precomputed
        self.compute_order = list(nx.dfs_postorder_nodes(g))
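
Note the edge direction here is reversed relative to Example #2 (consumer -> input), so the postorder yields inputs before the preprocessors that consume them; a quick standalone check with plain networkx:

    import networkx as nx

    g = nx.DiGraph()
    g.add_edge("rgb_resnet", "rgb")         # consumer -> input, as in the code above
    print(list(nx.dfs_postorder_nodes(g)))  # ['rgb', 'rgb_resnet']: inputs first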
Example #6
    def train(self) -> None:
        r"""Main method for DD-PPO.

        Returns:
            None
        """
        import apex

        self.local_rank, tcp_store = init_distrib_slurm(
            self.config.RL.DDPPO.distrib_backend
        )
        # add_signal_handlers()
        self.timing = Timing()

        # Stores the number of workers that have finished their rollout
        num_rollouts_done_store = distrib.PrefixStore("rollout_tracker", tcp_store)
        num_rollouts_done_store.set("num_done", "0")

        self.world_rank = distrib.get_rank()
        self.world_size = distrib.get_world_size()

        set_cpus(self.local_rank, self.world_size)

        self.config.defrost()
        self.config.TORCH_GPU_ID = self.local_rank
        self.config.SIMULATOR_GPU_ID = self.local_rank
        # Multiply by the number of simulators to make sure they also get unique seeds
        self.config.TASK_CONFIG.SEED += self.world_rank * self.config.SIM_BATCH_SIZE
        self.config.freeze()

        random.seed(self.config.TASK_CONFIG.SEED)
        np.random.seed(self.config.TASK_CONFIG.SEED)
        torch.manual_seed(self.config.TASK_CONFIG.SEED)

        if torch.cuda.is_available():
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")

        double_buffered = False
        self._num_worker_groups = self.config.NUM_PARALLEL_SCENES

        self._depth = self.config.DEPTH
        self._color = self.config.COLOR

        if self.config.TASK.lower() == "pointnav":
            self.observation_space = SpaceDict(
                {
                    "pointgoal_with_gps_compass": spaces.Box(
                        low=0.0, high=1.0, shape=(2,), dtype=np.float32
                    )
                }
            )
        else:
            self.observation_space = SpaceDict({})

        self.action_space = spaces.Discrete(4)

        if self._color:
            self.observation_space = SpaceDict(
                {
                    "rgb": spaces.Box(
                        low=np.finfo(np.float32).min,
                        high=np.finfo(np.float32).max,
                        shape=(3, *self.config.RESOLUTION),
                        dtype=np.uint8,
                    ),
                    **self.observation_space.spaces,
                }
            )

        if self._depth:
            self.observation_space = SpaceDict(
                {
                    "depth": spaces.Box(
                        low=np.finfo(np.float32).min,
                        high=np.finfo(np.float32).max,
                        shape=(1, *self.config.RESOLUTION),
                        dtype=np.float32,
                    ),
                    **self.observation_space.spaces,
                }
            )

        ppo_cfg = self.config.RL.PPO
        if not os.path.isdir(self.config.CHECKPOINT_FOLDER) and self.world_rank == 0:
            os.makedirs(self.config.CHECKPOINT_FOLDER)

        self._setup_actor_critic_agent(ppo_cfg)
        self.count_steps = 0
        burn_steps = 0
        burn_time = 0
        count_checkpoints = 0
        prev_time = 0
        self.update = 0

        LR_SCALE = (
            max(
                np.sqrt(
                    ppo_cfg.num_steps
                    * self.config.SIM_BATCH_SIZE
                    * ppo_cfg.num_accumulate_steps
                    / ppo_cfg.num_mini_batch
                    * self.world_size
                    / (128 * 2)
                ),
                1.0,
            )
            if (self.config.RL.DDPPO.scale_lr and not self.config.RL.PPO.ada_scale)
            else 1.0
        )

        def cosine_decay(x):
            if x < 1:
                return (np.cos(x * np.pi) + 1.0) / 2.0
            else:
                return 0.0

        def warmup_fn(x):
            return LR_SCALE * (0.5 + 0.5 * x)

        def decay_fn(x):
            return LR_SCALE * (DECAY_TARGET + (1 - DECAY_TARGET) * cosine_decay(x))

        DECAY_TARGET = (
            0.01 / LR_SCALE
            # NOTE: the "or True" below forces this branch unconditionally.
            if self.config.RL.PPO.ada_scale or True
            else (0.25 / LR_SCALE if self.config.RL.DDPPO.scale_lr else 1.0)
        )
        DECAY_PERCENT = 1.0 if self.config.RL.PPO.ada_scale or True else 0.5
        WARMUP_PERCENT = (
            0.01
            if (self.config.RL.DDPPO.scale_lr and not self.config.RL.PPO.ada_scale)
            else 0.0
        )

        def lr_fn():
            x = self.percent_done()
            if x < WARMUP_PERCENT:
                return warmup_fn(x / WARMUP_PERCENT)
            else:
                return decay_fn((x - WARMUP_PERCENT) / DECAY_PERCENT)

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer, lr_lambda=lambda x: lr_fn()
        )

        interrupted_state = load_interrupted_state(resume_from=self.resume_from)
        if interrupted_state is not None:
            self.agent.load_state_dict(interrupted_state["state_dict"])

        self.agent.init_amp(self.config.SIM_BATCH_SIZE)
        self.actor_critic.init_trt(self.config.SIM_BATCH_SIZE)
        self.actor_critic.script_net()
        self.agent.init_distributed(find_unused_params=False)

        if self.world_rank == 0:
            logger.info(
                "agent number of trainable parameters: {}".format(
                    sum(
                        param.numel()
                        for param in self.agent.parameters()
                        if param.requires_grad
                    )
                )
            )

        if self._static_encoder:
            self._encoder = self.actor_critic.net.visual_encoder
            self.observation_space = SpaceDict(
                {
                    "visual_features": spaces.Box(
                        low=np.finfo(np.float32).min,
                        high=np.finfo(np.float32).max,
                        shape=self._encoder.output_shape,
                        dtype=np.float32,
                    ),
                    **self.observation_space,
                }
            )
            with torch.no_grad():
                # NOTE: assumes `batch` already holds the current observations
                # here; it is not defined earlier in this method as shown.
                batch["visual_features"] = self._encoder(batch)

        nenvs = self.config.SIM_BATCH_SIZE
        rollouts = DoubleBufferedRolloutStorage(
            ppo_cfg.num_steps,
            nenvs,
            self.observation_space,
            self.action_space,
            ppo_cfg.hidden_size,
            num_recurrent_layers=self.actor_critic.num_recurrent_layers,
            use_data_aug=ppo_cfg.use_data_aug,
            aug_type=ppo_cfg.aug_type,
            double_buffered=double_buffered,
            vtrace=ppo_cfg.vtrace,
        )
        rollouts.to(self.device)
        rollouts.to_fp16()

        self._warmup(rollouts)

        (
            self.envs,
            self._observations,
            self._rewards,
            self._masks,
            self._rollout_infos,
            self._syncs,
        ) = construct_envs(
            self.config,
            num_worker_groups=self.config.NUM_PARALLEL_SCENES,
            double_buffered=double_buffered,
        )

        def _setup_render_and_populate_initial_frame():
            for idx in range(2 if double_buffered else 1):
                self.envs.reset(idx)

                batch = self._observations[idx]
                self._syncs[idx].wait()

                tree_copy_in_place(
                    tree_select(0, rollouts[idx].storage_buffers["observations"]),
                    batch,
                )

        _setup_render_and_populate_initial_frame()

        current_episode_reward = torch.zeros(nenvs, 1)
        running_episode_stats = dict(
            count=torch.zeros(nenvs, 1,), reward=torch.zeros(nenvs, 1,),
        )

        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size)
        )
        time_per_frame_window = deque(maxlen=ppo_cfg.reward_window_size)

        buffer_ranges = []
        for i in range(2 if double_buffered else 1):
            start_ind = buffer_ranges[-1].stop if i > 0 else 0
            buffer_ranges.append(
                slice(
                    start_ind,
                    start_ind
                    + self.config.SIM_BATCH_SIZE // (2 if double_buffered else 1),
                )
            )

        if interrupted_state is not None:
            requeue_stats = interrupted_state["requeue_stats"]

            self.count_steps = requeue_stats["count_steps"]
            self.update = requeue_stats["start_update"]
            count_checkpoints = requeue_stats["count_checkpoints"]
            prev_time = requeue_stats["prev_time"]
            burn_steps = requeue_stats["burn_steps"]
            burn_time = requeue_stats["burn_time"]

            self.agent.ada_scale.load_state_dict(interrupted_state["ada_scale_state"])

            lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

            if "amp_state" in interrupted_state:
                apex.amp.load_state_dict(interrupted_state["amp_state"])

            if "grad_scaler_state" in interrupted_state:
                self.agent.grad_scaler.load_state_dict(
                    interrupted_state["grad_scaler_state"]
                )

        with (
            TensorboardWriter(
                self.config.TENSORBOARD_DIR,
                flush_secs=self.flush_secs,
                purge_step=int(self.count_steps),
            )
            if self.world_rank == 0
            else contextlib.suppress()
        ) as writer:
            distrib.barrier()
            t_start = time.time()
            while not self.is_done():
                t_rollout_start = time.time()
                if self.update == BURN_IN_UPDATES:
                    burn_time = t_rollout_start - t_start
                    burn_steps = self.count_steps

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        self.percent_done(), final_decay=ppo_cfg.decay_factor,
                    )

                if (
                    not BPS_BENCHMARK
                    and (REQUEUE.is_set() or ((self.update + 1) % 100) == 0)
                    and self.world_rank == 0
                ):
                    requeue_stats = dict(
                        count_steps=self.count_steps,
                        count_checkpoints=count_checkpoints,
                        start_update=self.update,
                        prev_time=(time.time() - t_start) + prev_time,
                        burn_time=burn_time,
                        burn_steps=burn_steps,
                    )

                    def _cast(param):
                        if "Half" in param.type():
                            param = param.to(dtype=torch.float32)

                        return param

                    save_interrupted_state(
                        dict(
                            state_dict={
                                k: _cast(v) for k, v in self.agent.state_dict().items()
                            },
                            ada_scale_state=self.agent.ada_scale.state_dict(),
                            lr_sched_state=lr_scheduler.state_dict(),
                            config=self.config,
                            requeue_stats=requeue_stats,
                            grad_scaler_state=self.agent.grad_scaler.state_dict(),
                        )
                    )

                if EXIT.is_set():
                    self._observations = None
                    self._rewards = None
                    self._masks = None
                    self._rollout_infos = None
                    self._syncs = None

                    del self.envs
                    self.envs = None

                    requeue_job()
                    return

                self.agent.eval()

                count_steps_delta = self._n_buffered_sampling(
                    rollouts,
                    current_episode_reward,
                    running_episode_stats,
                    buffer_ranges,
                    ppo_cfg.num_steps,
                    num_rollouts_done_store,
                )

                num_rollouts_done_store.add("num_done", 1)

                if not rollouts.vtrace:
                    self._compute_returns(ppo_cfg, rollouts)

                (value_loss, action_loss, dist_entropy) = self._update_agent(rollouts)

                if self.world_rank == 0:
                    num_rollouts_done_store.set("num_done", "0")

                lr_scheduler.step()

                with self.timing.add_time("Logging"):
                    stats_ordering = list(sorted(running_episode_stats.keys()))
                    stats = torch.stack(
                        [running_episode_stats[k] for k in stats_ordering], 0,
                    ).to(device=self.device)
                    distrib.all_reduce(stats)
                    stats = stats.to(device="cpu")

                    for i, k in enumerate(stats_ordering):
                        window_episode_stats[k].append(stats[i])

                    stats = torch.tensor(
                        [
                            value_loss,
                            action_loss,
                            count_steps_delta,
                            *self.envs.swap_stats,
                        ],
                        device=self.device,
                    )
                    distrib.all_reduce(stats)
                    stats = stats.to(device="cpu")
                    count_steps_delta = int(stats[2].item())
                    self.count_steps += count_steps_delta

                    time_per_frame_window.append(
                        (time.time() - t_rollout_start) / count_steps_delta
                    )

                    if self.world_rank == 0:
                        losses = [
                            stats[0].item() / self.world_size,
                            stats[1].item() / self.world_size,
                        ]
                        deltas = {
                            k: (
                                (v[-1] - v[0]).sum().item()
                                if len(v) > 1
                                else v[0].sum().item()
                            )
                            for k, v in window_episode_stats.items()
                        }
                        deltas["count"] = max(deltas["count"], 1.0)

                        writer.add_scalar(
                            "reward",
                            deltas["reward"] / deltas["count"],
                            self.count_steps,
                        )

                        # Check to see if there are any metrics
                        # that haven't been logged yet
                        metrics = {
                            k: v / deltas["count"]
                            for k, v in deltas.items()
                            if k not in {"reward", "count"}
                        }
                        if len(metrics) > 0:
                            writer.add_scalars("metrics", metrics, self.count_steps)

                        writer.add_scalars(
                            "losses",
                            {k: l for l, k in zip(losses, ["value", "policy"])},
                            self.count_steps,
                        )

                        optim = self.agent.optimizer
                        writer.add_scalar(
                            "optimizer/base_lr",
                            optim.param_groups[-1]["lr"],
                            self.count_steps,
                        )
                        if "gain" in optim.param_groups[-1]:
                            for idx, group in enumerate(optim.param_groups):
                                writer.add_scalar(
                                    f"optimizer/lr_{idx}",
                                    group["lr"] * group["gain"],
                                    self.count_steps,
                                )
                                writer.add_scalar(
                                    f"optimizer/gain_{idx}",
                                    group["gain"],
                                    self.count_steps,
                                )

                        # log stats
                        if (
                            self.update > 0
                            and self.update % self.config.LOG_INTERVAL == 0
                        ):
                            logger.info(
                                "update: {}\twindow fps: {:.3f}\ttotal fps: {:.3f}\tframes: {}".format(
                                    self.update,
                                    1.0
                                    / (
                                        sum(time_per_frame_window)
                                        / len(time_per_frame_window)
                                    ),
                                    (self.count_steps - burn_steps)
                                    / ((time.time() - t_start) + prev_time - burn_time),
                                    self.count_steps,
                                )
                            )

                            logger.info(
                                "swap percent: {:.3f}\tscenes in use: {:.3f}\tenvs per scene: {:.3f}".format(
                                    stats[3].item() / self.world_size,
                                    stats[4].item() / self.world_size,
                                    stats[5].item() / self.world_size,
                                )
                            )

                            logger.info(
                                "Average window size: {}  {}".format(
                                    len(window_episode_stats["count"]),
                                    "  ".join(
                                        "{}: {:.3f}".format(k, v / deltas["count"])
                                        for k, v in deltas.items()
                                        if k != "count"
                                    ),
                                )
                            )

                            logger.info(self.timing)
                            # self.envs.print_renderer_stats()

                        # checkpoint model
                        if self.should_checkpoint():
                            self.save_checkpoint(
                                f"ckpt.{count_checkpoints}.pth",
                                dict(
                                    step=self.count_steps,
                                    wall_clock_time=(
                                        (time.time() - t_start) + prev_time
                                    ),
                                ),
                            )
                            count_checkpoints += 1

                self.update += 1

            self.save_checkpoint(
                "ckpt.done.pth",
                dict(
                    step=self.count_steps,
                    wall_clock_time=((time.time() - t_start) + prev_time),
                ),
            )
            self._observations = None
            self._rewards = None
            self._masks = None
            self._rollout_infos = None
            self._syncs = None
            del self.envs
            self.envs = None
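
The learning-rate schedule above is a linear warmup into a cosine decay toward DECAY_TARGET; this standalone sketch re-implements just that shape with made-up constants so it can be inspected in isolation:

    import numpy as np

    LR_SCALE, DECAY_TARGET = 2.0, 0.25       # illustrative values only
    WARMUP_PERCENT, DECAY_PERCENT = 0.01, 1.0

    def cosine_decay(x):
        return (np.cos(x * np.pi) + 1.0) / 2.0 if x < 1 else 0.0

    def lr_fn(percent_done):
        if percent_done < WARMUP_PERCENT:
            return LR_SCALE * (0.5 + 0.5 * percent_done / WARMUP_PERCENT)
        x = (percent_done - WARMUP_PERCENT) / DECAY_PERCENT
        return LR_SCALE * (DECAY_TARGET + (1 - DECAY_TARGET) * cosine_decay(x))

    for p in (0.0, 0.005, 0.01, 0.5, 1.0):   # the multiplier fed to LambdaLR
        print(p, round(lr_fn(p), 4))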
Example #7
    def __init__(self, config: Config) -> None:
        spaces = {
            get_default_config().GOAL_SENSOR_UUID: Box(
                low=np.finfo(np.float32).min,
                high=np.finfo(np.float32).max,
                shape=(2,),
                dtype=np.float32,
            )
        }

        if config.INPUT_TYPE in ["depth", "rgbd"]:
            spaces["depth"] = Box(
                low=0,
                high=1,
                shape=(config.RESOLUTION, config.RESOLUTION, 1),
                dtype=np.float32,
            )

        if config.INPUT_TYPE in ["rgb", "rgbd"]:
            spaces["rgb"] = Box(
                low=0,
                high=255,
                shape=(config.RESOLUTION, config.RESOLUTION, 3),
                dtype=np.uint8,
            )
        observation_spaces = SpaceDict(spaces)

        action_spaces = Discrete(4)

        self.device = (
            torch.device("cuda:{}".format(config.PTH_GPU_ID))
            if torch.cuda.is_available()
            else torch.device("cpu")
        )
        self.hidden_size = config.HIDDEN_SIZE

        random.seed(config.RANDOM_SEED)
        torch.random.manual_seed(config.RANDOM_SEED)
        if torch.cuda.is_available():
            torch.backends.cudnn.deterministic = True  # type: ignore

        self.actor_critic = PointNavResNetPolicy(
            observation_space=observation_spaces,
            action_space=action_spaces,
            hidden_size=self.hidden_size,
            normalize_visual_inputs="rgb" in spaces,
        )
        self.actor_critic.to(self.device)

        if config.MODEL_PATH:
            ckpt = torch.load(config.MODEL_PATH, map_location=self.device)
            #  Filter only actor_critic weights
            self.actor_critic.load_state_dict(
                {
                    k[len("actor_critic.") :]: v
                    for k, v in ckpt["state_dict"].items()
                    if "actor_critic" in k
                }
            )

        else:
            habitat.logger.error(
                "Model checkpoint wasn't loaded, evaluating " "a random model."
            )

        self.test_recurrent_hidden_states: Optional[torch.Tensor] = None
        self.not_done_masks: Optional[torch.Tensor] = None
        self.prev_actions: Optional[torch.Tensor] = None
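
The checkpoint-filtering comprehension near the end is a generic key-renaming pattern; a self-contained check with a toy state dict:

    state_dict = {
        "actor_critic.net.fc.weight": 1,  # kept, prefix stripped
        "optimizer.step": 2,              # dropped
    }
    filtered = {
        k[len("actor_critic."):]: v
        for k, v in state_dict.items()
        if "actor_critic" in k
    }
    print(filtered)  # {'net.fc.weight': 1}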
Example #8
    def _eval_checkpoint(
        self,
        checkpoint_path: str,
        writer: TensorboardWriter,
        checkpoint_index: int = 0,
    ) -> None:
        r"""Evaluates a single checkpoint.

        Args:
            checkpoint_path: path of checkpoint
            writer: tensorboard writer object for logging to tensorboard
            checkpoint_index: index of cur checkpoint for logging

        Returns:
            None
        """

        from habitat_baselines.common.environments import get_env_class

        # Map location CPU is almost always better than mapping to a CUDA device.
        ckpt_dict = self.load_checkpoint(checkpoint_path, map_location="cpu")

        if self.config.EVAL.USE_CKPT_CONFIG:
            config = self._setup_eval_config(ckpt_dict["config"])
        else:
            config = self.config.clone()

        ppo_cfg = config.RL.PPO

        config.defrost()
        config.TASK_CONFIG.DATASET.SPLIT = config.EVAL.SPLIT
        config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.SHUFFLE = False
        config.TASK_CONFIG.ENVIRONMENT.ITERATOR_OPTIONS.MAX_SCENE_REPEAT_STEPS = -1
        config.freeze()

        if len(self.config.VIDEO_OPTION) > 0:
            config.defrost()
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")
            config.TASK_CONFIG.TASK.MEASUREMENTS.append("COLLISIONS")
            config.freeze()

        #  logger.info(f"env config: {config}")
        self.envs = construct_envs_habitat(config,
                                           get_env_class(config.ENV_NAME))
        self.observation_space = SpaceDict({
            "pointgoal_with_gps_compass":
            spaces.Box(low=0.0, high=1.0, shape=(2, ), dtype=np.float32)
        })

        if self.config.COLOR:
            self.observation_space = SpaceDict({
                "rgb":
                spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=(3, *self.config.RESOLUTION),
                    dtype=np.uint8,
                ),
                **self.observation_space.spaces,
            })

        if self.config.DEPTH:
            self.observation_space = SpaceDict({
                "depth":
                spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=(1, *self.config.RESOLUTION),
                    dtype=np.float32,
                ),
                **self.observation_space.spaces,
            })

        self.action_space = self.envs.action_spaces[0]
        self._setup_actor_critic_agent(ppo_cfg)

        self.agent.load_state_dict(ckpt_dict["state_dict"])
        self.actor_critic = self.agent.actor_critic
        self.actor_critic.script_net()
        self.actor_critic.eval()

        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        current_episode_reward = torch.zeros(self.envs.num_envs,
                                             1,
                                             device=self.device)

        test_recurrent_hidden_states = torch.zeros(
            self.config.NUM_PROCESSES,
            self.actor_critic.num_recurrent_layers,
            ppo_cfg.hidden_size,
            device=self.device,
        )
        prev_actions = torch.zeros(self.config.NUM_PROCESSES,
                                   1,
                                   device=self.device,
                                   dtype=torch.long)
        not_done_masks = torch.zeros(self.config.NUM_PROCESSES,
                                     1,
                                     device=self.device,
                                     dtype=torch.bool)
        stats_episodes = dict()  # dict of dicts that stores stats per episode

        rgb_frames: List[List[np.ndarray]] = [
            [] for _ in range(self.config.NUM_PROCESSES)
        ]
        if len(self.config.VIDEO_OPTION) > 0:
            os.makedirs(self.config.VIDEO_DIR, exist_ok=True)

        number_of_eval_episodes = self.config.TEST_EPISODE_COUNT
        if number_of_eval_episodes == -1:
            number_of_eval_episodes = sum(self.envs.number_of_episodes)
        else:
            total_num_eps = sum(self.envs.number_of_episodes)
            if total_num_eps < number_of_eval_episodes:
                logger.warn(
                    f"Config specified {number_of_eval_episodes} eval episodes"
                    f", dataset only has {total_num_eps}."
                )
                logger.warn(f"Evaluating with {total_num_eps} instead.")
                number_of_eval_episodes = total_num_eps

        evals_per_ep = 5
        count_per_ep = defaultdict(lambda: 0)

        pbar = tqdm.tqdm(total=number_of_eval_episodes * evals_per_ep)
        self.actor_critic.eval()
        while (len(stats_episodes) < (number_of_eval_episodes * evals_per_ep)
               and self.envs.num_envs > 0):
            current_episodes = self.envs.current_episodes()

            with torch.no_grad():
                (_, dist_result,
                 test_recurrent_hidden_states) = self.actor_critic.act(
                     batch,
                     test_recurrent_hidden_states,
                     prev_actions,
                     not_done_masks,
                     deterministic=False,
                 )
                actions = dist_result["actions"]

                prev_actions.copy_(actions)
                actions = actions.to("cpu")

            outputs = self.envs.step([a[0].item() for a in actions])

            observations, rewards, dones, infos = [
                list(x) for x in zip(*outputs)
            ]
            batch = batch_obs(observations, device=self.device)

            not_done_masks = torch.tensor(
                [[False] if done else [True] for done in dones],
                dtype=torch.bool,
                device=self.device,
            )

            rewards = torch.tensor(rewards,
                                   dtype=torch.float,
                                   device=self.device).unsqueeze(1)
            current_episode_reward += rewards
            next_episodes = self.envs.current_episodes()
            envs_to_pause = []
            n_envs = self.envs.num_envs
            for i in range(n_envs):
                next_count_key = (
                    next_episodes[i].scene_id,
                    next_episodes[i].episode_id,
                )

                if count_per_ep[next_count_key] == evals_per_ep:
                    envs_to_pause.append(i)

                # episode ended
                if not_done_masks[i].item() == 0:
                    pbar.update()
                    episode_stats = dict()
                    episode_stats["reward"] = current_episode_reward[i].item()
                    episode_stats.update(
                        self._extract_scalars_from_info(infos[i]))
                    current_episode_reward[i] = 0
                    # use scene_id + episode_id as unique id for storing stats
                    count_key = (
                        current_episodes[i].scene_id,
                        current_episodes[i].episode_id,
                    )
                    count_per_ep[count_key] = count_per_ep[count_key] + 1

                    ep_key = (count_key, count_per_ep[count_key])
                    stats_episodes[ep_key] = episode_stats

                    if len(self.config.VIDEO_OPTION) > 0:
                        generate_video(
                            video_option=self.config.VIDEO_OPTION,
                            video_dir=self.config.VIDEO_DIR,
                            images=rgb_frames[i],
                            episode_id=current_episodes[i].episode_id,
                            checkpoint_idx=checkpoint_index,
                            metrics=self._extract_scalars_from_info(infos[i]),
                            tb_writer=writer,
                        )

                        rgb_frames[i] = []

                # episode continues
                elif len(self.config.VIDEO_OPTION) > 0:
                    frame = observations_to_image(observations[i], infos[i])
                    rgb_frames[i].append(frame)

            (
                self.envs,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            ) = self._pause_envs(
                envs_to_pause,
                self.envs,
                test_recurrent_hidden_states,
                not_done_masks,
                current_episode_reward,
                prev_actions,
                batch,
                rgb_frames,
            )

        self.envs.close()
        pbar.close()
        num_episodes = len(stats_episodes)
        aggregated_stats = dict()
        for stat_key in next(iter(stats_episodes.values())).keys():
            values = [
                v[stat_key] for v in stats_episodes.values()
                if v[stat_key] is not None
            ]
            if len(values) > 0:
                aggregated_stats[stat_key] = sum(values) / len(values)
            else:
                aggregated_stats[stat_key] = 0

        for k, v in aggregated_stats.items():
            logger.info(f"Average episode {k}: {v:.4f}")

        step_id = checkpoint_index
        if "extra_state" in ckpt_dict and "step" in ckpt_dict["extra_state"]:
            step_id = ckpt_dict["extra_state"]["step"]

        writer.add_scalars(
            "eval_reward",
            {"average reward": aggregated_stats["reward"]},
            step_id,
        )

        metrics = {k: v for k, v in aggregated_stats.items() if k != "reward"}
        if len(metrics) > 0:
            writer.add_scalars("eval_metrics", metrics, step_id)

        self.num_frames = step_id

        time.sleep(30)
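
The aggregation at the end averages each stat across episodes while skipping None values; a toy run of the same logic:

    stats_episodes = {
        ("scene0", "ep0"): {"reward": 1.0, "spl": 0.5},
        ("scene0", "ep1"): {"reward": 3.0, "spl": None},
    }
    aggregated = {}
    for key in next(iter(stats_episodes.values())):
        vals = [v[key] for v in stats_episodes.values() if v[key] is not None]
        aggregated[key] = sum(vals) / len(vals) if vals else 0
    print(aggregated)  # {'reward': 2.0, 'spl': 0.5}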
Example #9
    def __init__(
        self, config: Config, dataset: Optional[Dataset] = None
    ) -> None:
        """Constructor

        :param config: config for the environment. Should contain id for
            simulator and ``task_name`` which are passed into ``make_sim`` and
            ``make_task``.
        :param dataset: reference to dataset for task instance level
            information. Can be defined as :py:`None` in which case
            ``_episodes`` should be populated from outside.
        """

        assert config.is_frozen(), (
            "Freeze the config before creating the "
            "environment, use config.freeze()."
        )
        self._config = config
        self._dataset = dataset
        self._current_episode_index = None
        self._override_rand_goal = (
            config.ENVIRONMENT.OVERRIDE_RAND_GOAL.ENABLED
        )
        self._override_min_d = config.ENVIRONMENT.OVERRIDE_RAND_GOAL.MIN_DIST
        self._override_radius = config.ENVIRONMENT.OVERRIDE_RAND_GOAL.RADIUS
        self._angle = config.SIMULATOR.TURN_ANGLE

        if self._dataset is None and config.DATASET.TYPE:
            self._dataset = make_dataset(
                id_dataset=config.DATASET.TYPE, config=config.DATASET
            )
        self._episodes = self._dataset.episodes if self._dataset else []
        self._current_episode = None
        iter_option_dict = {
            k.lower(): v
            for k, v in config.ENVIRONMENT.ITERATOR_OPTIONS.items()
        }
        self._episode_iterator = self._dataset.get_episode_iterator(
            **iter_option_dict
        )

        # load the first scene if dataset is present
        if self._dataset:
            assert (
                len(self._dataset.episodes) > 0
            ), "dataset should have non-empty episodes list"
            self._config.defrost()
            self._config.SIMULATOR.SCENE = self._dataset.episodes[0].scene_id
            self._config.freeze()

        self._sim = make_sim(
            id_sim=self._config.SIMULATOR.TYPE, config=self._config.SIMULATOR
        )
        self._task = make_task(
            self._config.TASK.TYPE,
            config=self._config.TASK,
            sim=self._sim,
            dataset=self._dataset,
        )
        self.observation_space = SpaceDict(
            {
                **self._sim.sensor_suite.observation_spaces.spaces,
                **self._task.sensor_suite.observation_spaces.spaces,
            }
        )
        self.action_space = self._task.action_space
        self._max_episode_seconds = (
            self._config.ENVIRONMENT.MAX_EPISODE_SECONDS
        )
        self._max_episode_steps = self._config.ENVIRONMENT.MAX_EPISODE_STEPS
        self._elapsed_steps = 0
        self._episode_start_time: Optional[float] = None
        self._episode_over = False
        self._seed_id = None
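
A usage sketch for this constructor, assuming habitat-lab is installed; the config path below is illustrative, and get_config returns an already-frozen config:

    import habitat

    config = habitat.get_config("configs/tasks/pointnav.yaml")  # illustrative path
    env = habitat.Env(config=config)
    observations = env.reset()
    print(env.observation_space, env.action_space)
    env.close()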
Example #10
    def __init__(self, config: Config) -> None:
        image_size = config.RL.POLICY.OBS_TRANSFORMS.CENTER_CROPPER
        if "ObjectNav" in config.TASK_CONFIG.TASK.TYPE:
            OBJECT_CATEGORIES_NUM = 20
            spaces = {
                "objectgoal":
                Box(low=0,
                    high=OBJECT_CATEGORIES_NUM,
                    shape=(1, ),
                    dtype=np.int64),
                "compass":
                Box(low=-np.pi, high=np.pi, shape=(1, ), dtype=np.float32),
                "gps":
                Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=(2, ),
                    dtype=np.float32,
                ),
            }
        else:
            spaces = {
                "pointgoal":
                Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=(2, ),
                    dtype=np.float32,
                )
            }

        if config.INPUT_TYPE in ["depth", "rgbd"]:
            spaces["depth"] = Box(
                low=0,
                high=1,
                shape=(image_size.HEIGHT, image_size.WIDTH, 1),
                dtype=np.float32,
            )

        if config.INPUT_TYPE in ["rgb", "rgbd"]:
            spaces["rgb"] = Box(
                low=0,
                high=255,
                shape=(image_size.HEIGHT, image_size.WIDTH, 3),
                dtype=np.uint8,
            )
        observation_spaces = SpaceDict(spaces)
        action_spaces = (Discrete(6) if "ObjectNav"
                         in config.TASK_CONFIG.TASK.TYPE else Discrete(4))
        self.obs_transforms = get_active_obs_transforms(config)
        observation_spaces = apply_obs_transforms_obs_space(
            observation_spaces, self.obs_transforms)

        self.device = (torch.device("cuda:{}".format(config.PTH_GPU_ID))
                       if torch.cuda.is_available() else torch.device("cpu"))
        self.hidden_size = config.RL.PPO.hidden_size

        random.seed(config.RANDOM_SEED)
        np.random.seed(config.RANDOM_SEED)
        _seed_numba(config.RANDOM_SEED)
        torch.random.manual_seed(config.RANDOM_SEED)
        if torch.cuda.is_available():
            torch.backends.cudnn.deterministic = True  # type: ignore

        policy = baseline_registry.get_policy(config.RL.POLICY.name)
        self.actor_critic = policy.from_config(config, observation_spaces,
                                               action_spaces)

        self.actor_critic.to(self.device)

        if config.MODEL_PATH:
            ckpt = torch.load(config.MODEL_PATH, map_location=self.device)
            #  Filter only actor_critic weights
            self.actor_critic.load_state_dict({
                k[len("actor_critic."):]: v
                for k, v in ckpt["state_dict"].items() if "actor_critic" in k
            })

        else:
            habitat.logger.error("Model checkpoint wasn't loaded, evaluating "
                                 "a random model.")

        self.test_recurrent_hidden_states: Optional[torch.Tensor] = None
        self.not_done_masks: Optional[torch.Tensor] = None
        self.prev_actions: Optional[torch.Tensor] = None
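
The INPUT_TYPE branching above composes the observation space incrementally; a minimal sketch of the same composition with plain gym spaces (all values illustrative):

    import numpy as np
    from gym.spaces import Box, Dict as SpaceDict

    input_type, h, w = "rgbd", 256, 256  # illustrative settings
    spaces = {
        "pointgoal": Box(np.finfo(np.float32).min,
                         np.finfo(np.float32).max, (2,), np.float32)
    }
    if input_type in ["depth", "rgbd"]:
        spaces["depth"] = Box(0, 1, (h, w, 1), np.float32)
    if input_type in ["rgb", "rgbd"]:
        spaces["rgb"] = Box(0, 255, (h, w, 3), np.uint8)
    print(SpaceDict(spaces))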