Example #1
0
    def log(
        self,
        log_writer: SummaryWriter,
        task_outputs: Optional[List[Any]],
        render: Optional[Dict[str, List[Dict[str, Any]]]],
        num_steps: int,
    ):
        """Log trajectory figures to TensorBoard, one figure group per page.

        Args:
            log_writer: Writer that receives each page of figures through
                ``add_figure``.
            task_outputs: Forwarded to ``self._auto_viz_order``, which returns
                the pages of episode ids to draw and the episode data indexed
                by episode id.
            render: Not used by this visualizer.
            num_steps: Global step recorded with every logged figure group.

        If ``_auto_viz_order`` yields no ordering, nothing is logged.
        """
        pages, episodes_by_id = self._auto_viz_order(task_outputs)
        if pages is None:
            get_logger().debug("trajectory viz returning without visualizing")
            return

        for group_index, episode_ids in enumerate(pages):
            page_figures = []
            for ep_id in episode_ids:
                if ep_id in episodes_by_id:
                    page_figures.append(self.make_fig(episodes_by_id[ep_id], ep_id))
                else:
                    # Tolerate ids that the ordering lists but the data lacks.
                    get_logger().warning(
                        "skipping viz for missing episode {}".format(ep_id)
                    )

            if not page_figures:
                continue

            tag = "{}/{}_group{}".format(self.mode, self.label, group_index)
            log_writer.add_figure(tag, page_figures, global_step=num_steps)
            # SummaryWriter already closes the figures it logs; this also
            # closes any other figure still open in pyplot.
            plt.close("all")
Example #2
0
    def log(
        self,
        log_writer: SummaryWriter,
        task_outputs: Optional[List[Any]],
        render: Optional[Dict[str, List[Dict[str, Any]]]],
        num_steps: int,
    ):
        """Log actor-probability figures to TensorBoard, one group per page.

        Args:
            log_writer: Writer that receives each page of figures through
                ``add_figure``.
            task_outputs: Forwarded to ``self._auto_viz_order``. When
                ``self.action_names`` is still unset, the first element is
                also used to resolve it via
                ``self._access(task_outputs[0], self.action_names_path)``.
            render: Per-episode lists of step dicts, indexed by episode id.
                Every step of a visualized episode must carry an
                ``"actor_probs"`` entry. Nothing is logged when ``None``.
            num_steps: Global step recorded with every logged figure group.

        Episodes that are missing from ``render`` or have no steps are
        skipped with a warning.

        Raises:
            ValueError: if some step of a visualized episode lacks an
                ``"actor_probs"`` entry.
        """
        if render is None:
            return

        # Resolve action names lazily, once, from the first task output.
        if (
            self.action_names is None
            and task_outputs is not None
            and len(task_outputs) > 0
            and self.action_names_path is not None
        ):
            self.action_names = list(
                self._access(task_outputs[0], self.action_names_path)
            )

        viz_order, _ = self._auto_viz_order(task_outputs)
        if viz_order is None:
            get_logger().debug("actor viz returning without visualizing")
            return

        for page, current_ids in enumerate(viz_order):
            figs = []
            for episode_id in current_ids:
                # An empty episode has nothing to plot, so treat it like a
                # missing one, as the tensor visualizer does.
                if episode_id not in render or len(render[episode_id]) == 0:
                    get_logger().warning(
                        "skipping viz for missing or 0-length episode {}".format(
                            episode_id
                        )
                    )
                    continue
                episode_steps = render[episode_id]
                episode_src = [
                    step["actor_probs"]
                    for step in episode_steps
                    if "actor_probs" in step
                ]
                # Explicit check instead of `assert`, so it still runs under
                # `python -O`.
                if len(episode_src) != len(episode_steps):
                    raise ValueError(
                        "episode {} has {} render steps but only {} contain "
                        "'actor_probs'".format(
                            episode_id, len(episode_steps), len(episode_src)
                        )
                    )
                figs.append(self.make_fig(episode_src, episode_id))
            if len(figs) == 0:
                continue
            log_writer.add_figure(
                "{}/{}_group{}".format(self.mode, self.label, page),
                figs,
                global_step=num_steps,
            )
            plt.close(
                "all"
            )  # close all current figures (SummaryWriter already closes all figures we log)
Example #3
0
    def log(
        self,
        log_writer: SummaryWriter,
        task_outputs: Optional[List[Any]],
        render: Optional[Dict[str, List[Dict[str, Any]]]],
        num_steps: int,
    ):
        """Log figures of the ``self.datum_id`` render entries, one group per page.

        Args:
            log_writer: Writer that receives each page of figures through
                ``add_figure``.
            task_outputs: Forwarded to ``self._auto_viz_order`` to obtain the
                pages of episode ids to draw.
            render: Per-episode lists of step dicts, indexed by episode id.
                Nothing is logged when ``None``.
            num_steps: Global step recorded with every logged figure group.

        Episodes that are missing from ``render`` or have no steps are
        skipped with a warning. Episodes whose steps contain no
        ``self.datum_id`` entry are skipped silently.
        """
        if render is None:
            return

        pages, _ = self._auto_viz_order(task_outputs)
        if pages is None:
            get_logger().debug("tensor viz returning without visualizing")
            return

        for group_index, episode_ids in enumerate(pages):
            page_figures = []
            for ep_id in episode_ids:
                has_steps = ep_id in render and len(render[ep_id]) > 0
                if not has_steps:
                    get_logger().warning(
                        "skipping viz for missing or 0-length episode {}".format(
                            ep_id
                        )
                    )
                    continue

                tensors = [
                    step[self.datum_id]
                    for step in render[ep_id]
                    if self.datum_id in step
                ]
                # An episode can still yield no entries; for example, the last
                # length-1 episode of an inference worker captures no rollout
                # sources. Only non-empty sources become figures.
                if tensors:
                    page_figures.append(self.make_fig(tensors, ep_id))

            if not page_figures:
                continue

            tag = "{}/{}_group{}".format(self.mode, self.label, group_index)
            log_writer.add_figure(tag, page_figures, global_step=num_steps)
            # SummaryWriter already closes the figures it logs; this also
            # closes any other figure still open in pyplot.
            plt.close("all")