コード例 #1
0
ファイル: viz_utils.py プロジェクト: kolbytn/allenact
    def log(
        self,
        log_writer: SummaryWriter,
        task_outputs: Optional[List[Any]],
        render: Optional[Dict[str, List[Dict[str, Any]]]],
        num_steps: int,
    ):
        """Log one group of figures per page of the automatic visualization order.

        `self._auto_viz_order(task_outputs)` returns both the page layout (a
        sequence of episode-id groups, or None to skip visualization) and a
        mapping from episode id to episode data. Each episode found in that
        mapping is drawn with `self.make_fig`. Each page with at least one
        figure is written under the tag "<mode>/<label>_group<page>".

        Args:
            log_writer: writer that receives the figures via `add_figure`.
            task_outputs: per-task outputs handed to `_auto_viz_order`.
            render: not used by this visualizer.
            num_steps: global step attached to the logged figures.
        """
        order, episodes_by_id = self._auto_viz_order(task_outputs)
        if order is None:
            get_logger().debug("trajectory viz returning without visualizing")
            return

        for page_idx, page_ids in enumerate(order):
            page_figs = []
            for ep_id in page_ids:
                if ep_id in episodes_by_id:
                    page_figs.append(self.make_fig(episodes_by_id[ep_id], ep_id))
                else:
                    get_logger().warning(
                        "skipping viz for missing episode {}".format(ep_id))
            if not page_figs:
                continue
            tag = "{}/{}_group{}".format(self.mode, self.label, page_idx)
            log_writer.add_figure(tag, page_figs, global_step=num_steps)
            # Close every open figure; the figures just logged are already
            # closed by the SummaryWriter.
            plt.close("all")
コード例 #2
0
    def process_eval_package(self, log_writer: SummaryWriter,
                             pkg: LoggingPackage):
        """Log the metrics (and optional visualizations) of one eval package.

        Writes the number of evaluated tasks and every mean metric as scalars
        under "<mode>/...", at the package's training step. Emits a one-line
        info summary, then forwards the package's metric dicts and viz data to
        `self.visualizer` when one is configured.

        Args:
            log_writer: destination for the scalars and visualizations.
            pkg: a single evaluation LoggingPackage.
        """
        steps = pkg.training_steps
        ckpt_name = pkg.checkpoint_file_name
        num_tasks = pkg.num_non_empty_metrics_dicts_added
        means = pkg.metrics_tracker.means()
        mode = pkg.mode

        log_writer.add_scalar(f"{mode}/num_tasks_evaled", num_tasks, steps)

        parts = [f"{mode} {steps} steps:"]
        for key in sorted(means):
            value = means[key]
            log_writer.add_scalar(f"{mode}/{key}", value, steps)
            parts.append(f"{key} {value}")
        parts.append(f"tasks {num_tasks} checkpoint {ckpt_name}")
        get_logger().info(" ".join(parts))

        if self.visualizer is None:
            return
        self.visualizer.log(
            log_writer=log_writer,
            task_outputs=pkg.metric_dicts,
            render=pkg.viz_data,
            num_steps=steps,
        )
コード例 #3
0
    def log(
        self,
        log_writer: SummaryWriter,
        task_outputs: Optional[List[Any]],
        render: Optional[Dict[str, List[Dict[str, Any]]]],
        num_steps: int,
    ):
        """Log figures built from one rendered datum per step, page by page.

        For each page in `self.episode_ids`, every episode present in `render`
        with at least one step is handled as follows: `step[self.datum_id]` is
        collected from the steps that contain it, and the result goes to
        `self.make_fig`. Missing or empty episodes are skipped with a warning.
        Pages with no figures are not logged.

        Args:
            log_writer: writer that receives the figures via `add_figure`.
            task_outputs: not used by this visualizer.
            render: mapping of episode id to its list of per-step dicts; when
                None, nothing is logged.
            num_steps: global step attached to the logged figures.
        """
        if render is None:
            return

        datum_key = self.datum_id
        for page_idx, page_ids in enumerate(self.episode_ids):
            page_figs = []
            for ep_id in page_ids:
                if ep_id in render and len(render[ep_id]) > 0:
                    series = [
                        step[datum_key]
                        for step in render[ep_id]
                        if datum_key in step
                    ]
                    page_figs.append(self.make_fig(series, ep_id))
                else:
                    get_logger().warning(
                        "skipping viz for missing or 0-length episode {}".
                        format(ep_id))
            if not page_figs:
                continue
            tag = "{}/{}_group{}".format(self.mode, self.label, page_idx)
            log_writer.add_figure(tag, page_figs, global_step=num_steps)
            # Close every open figure; the figures just logged are already
            # closed by the SummaryWriter.
            plt.close("all")
コード例 #4
0
    def log(
        self,
        log_writer: SummaryWriter,
        task_outputs: Optional[List[Any]],
        render: Optional[Dict[str, List[Dict[str, Any]]]],
        num_steps: int,
    ):
        """Log one video per page of episodes listed in `self.episode_ids`.

        The per-step key holding the frames is derived from the first entry
        of `self.vector_task_sources`. For each page, the frame list of every
        episode present in `render` is collected, and `self.make_vid` combines
        those lists. The result is written with `log_writer.add_vid` unless it
        is None. Episodes missing from `render` are skipped with a warning.

        Args:
            log_writer: writer exposing `add_vid` (a project SummaryWriter).
            task_outputs: not used by this visualizer.
            render: mapping of episode id to its list of per-step dicts; when
                None, nothing is logged.
            num_steps: global step attached to the logged video.
        """
        if render is None:
            return

        frame_key = self._source_to_str(self.vector_task_sources[0],
                                        is_vector_task=True)
        for page_idx, page_ids in enumerate(self.episode_ids):
            # One list of (rgb) frames per episode shown on this page.
            episode_frames = []
            for ep_id in page_ids:
                if ep_id in render:
                    # TODO: overlay the episode id on the frames?
                    episode_frames.append(
                        [step[frame_key] for step in render[ep_id]])
                    continue
                get_logger().warning(
                    "skipping viz for missing episode {}".format(ep_id))
            if not episode_frames:
                continue
            vid = self.make_vid(episode_frames)
            if vid is None:
                continue
            log_writer.add_vid(
                "{}/{}_group{}".format(self.mode, self.label, page_idx),
                vid,
                global_step=num_steps,
            )
コード例 #5
0
    def process_train_packages(
        self,
        log_writer: SummaryWriter,
        pkgs: List[LoggingPackage],
        last_steps=0,
        last_offpolicy_steps=0,
        last_time=0.0,
    ):
        """Aggregate one batch of train packages and log their averaged scalars.

        The step counters and pipeline stage are read from the first package.
        Metric and train-info means from all packages are combined, weighted
        by their counts, in a ScalarMeanTracker. Each mean is logged under
        "train/<key>", except keys containing "offpolicy", which are logged
        unprefixed. Approximate steps-per-second is also logged when previous
        step counts are given.

        Args:
            log_writer: destination for the scalars.
            pkgs: train LoggingPackages for the same step counts.
            last_steps: on-policy step count of the previous call (0 = none).
            last_offpolicy_steps: off-policy step count of the previous call
                (0 = none).
            last_time: `time.time()` value returned by the previous call.

        Returns:
            (training_steps, off_policy_steps, current_time), meant to be fed
            back as last_steps / last_offpolicy_steps / last_time next time.
        """
        assert self.mode == "train"

        now = time.time()

        head = pkgs[0]
        steps = head.training_steps
        offpolicy = head.off_policy_steps
        log_writer.add_scalar(
            tag="train/pipeline_stage",
            scalar_value=head.pipeline_stage,
            global_step=steps,
        )

        tracker = ScalarMeanTracker()
        for pkg in pkgs:
            for source in (pkg.metrics_tracker, pkg.train_info_tracker):
                tracker.add_scalars(scalars=source.means(), n=source.counts())

        parts = ["train {} steps {} offpolicy:".format(steps, offpolicy)]
        means = tracker.means()
        # Keys without a "/" namespace come first; alphabetical within groups.
        for key in sorted(means, key=lambda k: ("/" in k, k)):
            value = means[key]
            if "offpolicy" in key:
                tag = key
            else:
                tag = "{}/".format(self.mode) + key
            log_writer.add_scalar(tag, value, steps)
            parts.append(key + " {:.3g}".format(value))

        elapsed = now - last_time
        parts.append("elapsed_time {:.3g}s".format(elapsed))

        if last_steps > 0:
            fps = (steps - last_steps) / elapsed
            parts.append("approx_fps {:.3g}".format(fps))
            log_writer.add_scalar("train/approx_fps", fps, steps)

        if last_offpolicy_steps > 0:
            fps = (offpolicy - last_offpolicy_steps) / elapsed
            parts.append("offpolicy/approx_fps {:.3g}".format(fps))
            log_writer.add_scalar("offpolicy/approx_fps", fps, steps)

        get_logger().info(" ".join(parts))

        return steps, offpolicy, now
コード例 #6
0
    def process_test_packages(
        self,
        log_writer: SummaryWriter,
        pkgs: List[LoggingPackage],
        all_results: Optional[List[Any]] = None,
    ):
        """Aggregate the test packages of all workers for one checkpoint and log them.

        Metric means are combined (count-weighted) across packages, and the
        per-task metric dicts are concatenated. Viz data is merged into one
        render dict. `all_equal` is asserted over the checkpoint file names.
        Each mean is logged under "test/<key>", along with the total number of
        evaluated tasks. The visualizer, if any, is run on the merged data.

        Args:
            log_writer: destination for scalars and visualizations.
            pkgs: test LoggingPackages for the same training step.
            all_results: if given, a dict with the mean metrics plus
                "training_steps" and "tasks" is appended to it.
        """
        mode = pkgs[0].mode
        assert mode == "test"

        steps = pkgs[0].training_steps

        tracker = ScalarMeanTracker()
        task_dicts = []
        render = {}
        ckpt_names = []
        for pkg in pkgs:
            tracker.add_scalars(
                scalars=pkg.metrics_tracker.means(),
                n=pkg.metrics_tracker.counts())
            task_dicts.extend(pkg.metric_dicts)
            if pkg.viz_data is not None:
                render.update(pkg.viz_data)
            ckpt_names.append(pkg.checkpoint_file_name)

        assert all_equal(ckpt_names)

        parts = [f"{mode} {steps} steps:"]
        means = tracker.means()
        for key in sorted(means):
            value = means[key]
            log_writer.add_scalar(f"{mode}/{key}", value, steps)
            parts.append(key + " {:.3g}".format(value))

        if all_results is not None:
            summary = copy.deepcopy(means)
            summary.update({"training_steps": steps, "tasks": task_dicts})
            all_results.append(summary)

        num_tasks = sum(pkg.num_non_empty_metrics_dicts_added for pkg in pkgs)
        log_writer.add_scalar(f"{mode}/num_tasks_evaled", num_tasks, steps)

        parts.append("tasks {} checkpoint {}".format(num_tasks, ckpt_names[0]))
        get_logger().info(" ".join(parts))

        if self.visualizer is None:
            return
        self.visualizer.log(
            log_writer=log_writer,
            task_outputs=task_dicts,
            render=render,
            num_steps=steps,
        )
コード例 #7
0
ファイル: viz_utils.py プロジェクト: kolbytn/allenact
    def log(
        self,
        log_writer: SummaryWriter,
        task_outputs: Optional[List[Any]],
        render: Optional[Dict[str, List[Dict[str, Any]]]],
        num_steps: int,
    ):
        """Log figures of per-step "actor_probs" for automatically chosen episodes.

        If `self.action_names` is unset but `self.action_names_path` is set,
        the names are read once from the first task output. The page layout
        comes from `self._auto_viz_order(task_outputs)`. For every episode
        present in `render`, its "actor_probs" entries go to `self.make_fig`.
        Each episode is asserted to have "actor_probs" on every step.

        Args:
            log_writer: writer that receives the figures via `add_figure`.
            task_outputs: per-task outputs used for action names and ordering.
            render: mapping of episode id to its list of per-step dicts; when
                None, nothing is logged.
            num_steps: global step attached to the logged figures.
        """
        if render is None:
            return

        # Lazily populate action names from the first task output.
        if self.action_names is None and self.action_names_path is not None:
            if task_outputs is not None and len(task_outputs) > 0:
                self.action_names = list(
                    self._access(task_outputs[0], self.action_names_path))

        order, _ = self._auto_viz_order(task_outputs)
        if order is None:
            get_logger().debug("actor viz returning without visualizing")
            return

        for page_idx, page_ids in enumerate(order):
            page_figs = []
            for ep_id in page_ids:
                if ep_id not in render:
                    get_logger().warning(
                        "skipping viz for missing episode {}".format(ep_id))
                    continue
                ep_steps = render[ep_id]
                probs = [
                    step["actor_probs"] for step in ep_steps
                    if "actor_probs" in step
                ]
                # Every rendered step is expected to carry actor probabilities.
                assert len(probs) == len(ep_steps)
                page_figs.append(self.make_fig(probs, ep_id))
            if not page_figs:
                continue
            tag = "{}/{}_group{}".format(self.mode, self.label, page_idx)
            log_writer.add_figure(tag, page_figs, global_step=num_steps)
            # Close every open figure; the figures just logged are already
            # closed by the SummaryWriter.
            plt.close("all")
コード例 #8
0
    def log(
        self,
        log_writer: SummaryWriter,
        task_outputs: Optional[List[Any]],
        render: Optional[Dict[str, List[Dict[str, Any]]]],
        num_steps: int,
    ):
        """Log figures built from task outputs for the episodes in `self.episode_ids`.

        Task outputs are indexed by the id found at `self.path_to_id`. If two
        outputs share an id, the later one wins. Each listed episode that is
        present is drawn with `self.make_fig`; missing ones are skipped with a
        warning. Pages with no figures are not logged.

        Args:
            log_writer: writer that receives the figures via `add_figure`.
            task_outputs: per-task outputs; when None, nothing is logged.
            render: not used by this visualizer.
            num_steps: global step attached to the logged figures.
        """
        if task_outputs is None:
            return

        episodes_by_id = {}
        for output in task_outputs:
            episodes_by_id[self._access(output, self.path_to_id)] = output

        for page_idx, page_ids in enumerate(self.episode_ids):
            page_figs = []
            for ep_id in page_ids:
                if ep_id in episodes_by_id:
                    page_figs.append(self.make_fig(episodes_by_id[ep_id], ep_id))
                else:
                    get_logger().warning(
                        "skipping viz for missing episode {}".format(ep_id))
            if not page_figs:
                continue
            tag = "{}/{}_group{}".format(self.mode, self.label, page_idx)
            log_writer.add_figure(tag, page_figs, global_step=num_steps)
            # Close every open figure; the figures just logged are already
            # closed by the SummaryWriter.
            plt.close("all")
コード例 #9
0
    def log(
            self,
            start_time_str: str,
            nworkers: int,
            test_steps: Sequence[int] = (),
            metrics_file: Optional[str] = None,
    ):
        """Consume logging packages from the results queue until the run ends.

        Reads packages from `self.queues["results"]`, waiting up to 1 s per
        read, and dispatches them by mode:

        * "train" packages are buffered until `nworkers` of them report the
          same (training_steps, off_policy_steps). That batch is then
          aggregated by `process_train_packages`.
        * "valid" packages are handled one at a time by `process_eval_package`.
        * "test" packages are buffered until `nworkers` of them report the
          same training_steps. That batch is aggregated by
          `process_test_packages`. The accumulated results are then rewritten
          to `metrics_file`, when one is given.

        Tuple packages ("train_stopped", "valid_stopped", "test_stopped")
        signal worker termination. Position 1 holds 0 on a clean stop;
        otherwise it appears to encode the failing worker's index + 1.

        The loop ends when any of the following happens:
        * training stops cleanly and validation is not running;
        * a validation package arrives after training stopped and the
          checkpoints queue is empty;
        * the last tester stops;
        * every worker process has exited.
        Exceptions (and KeyboardInterrupt) are logged and end the loop. The
        writer and runner are then closed.

        Args:
            start_time_str: used to build the TensorBoard log directory.
            nworkers: number of train/test workers whose packages are
                aggregated together.
            test_steps: step counts of the tested checkpoints; only used in a
                debug message.
            metrics_file: path the test results are dumped to as JSON. When
                None, results are only returned.

        Returns:
            The list of aggregated test result dicts (empty unless testing).
        """
        finalized = False

        log_writer = SummaryWriter(
            log_dir=self.log_writer_path(start_time_str),
            filename_suffix="__{}_{}".format(self.mode,
                                             self.local_start_time_str),
        )

        # To aggregate/buffer metrics from trainers/testers
        collected: List[LoggingPackage] = []
        last_train_steps = 0
        last_offpolicy_steps = 0
        last_train_time = time.time()
        test_results: List[Dict] = []
        unfinished_workers = nworkers

        try:
            while True:
                try:
                    package: Union[LoggingPackage, Union[
                        Tuple[str, Any],
                        Tuple[str, Any,
                              Any]]] = self.queues["results"].get(timeout=1)

                    if isinstance(package, LoggingPackage):
                        pkg_mode = package.mode

                        if pkg_mode == "train":
                            collected.append(package)
                            if len(collected) >= nworkers:

                                collected = sorted(
                                    collected,
                                    key=lambda pkg: (
                                        pkg.training_steps,
                                        pkg.off_policy_steps,
                                    ),
                                )

                                # Aggregate only once nworkers packages agree
                                # on both step counters.
                                if (
                                        collected[nworkers - 1].training_steps
                                        == collected[0].training_steps
                                        and collected[nworkers -
                                                      1].off_policy_steps
                                        == collected[0].off_policy_steps
                                ):
                                    (
                                        last_train_steps,
                                        last_offpolicy_steps,
                                        last_train_time,
                                    ) = self.process_train_packages(
                                        log_writer=log_writer,
                                        pkgs=collected[:nworkers],
                                        last_steps=last_train_steps,
                                        last_offpolicy_steps=
                                        last_offpolicy_steps,
                                        last_time=last_train_time,
                                    )
                                    collected = collected[nworkers:]
                                elif len(collected) > 2 * nworkers:
                                    # Trailing space needed: the two literals
                                    # are concatenated implicitly.
                                    get_logger().warning(
                                        "Unable to aggregate train packages from all {} workers "
                                        "after {} packages collected".format(
                                            nworkers, len(collected)))
                        elif pkg_mode == "valid":  # they all come from a single worker
                            if (package.training_steps
                                    is not None):  # no validation samplers
                                self.process_eval_package(
                                    log_writer=log_writer, pkg=package)
                            if (
                                    finalized
                                    and self.queues["checkpoints"].empty()
                            ):  # assume queue is actually empty after trainer finished and no checkpoints in queue
                                break
                        elif pkg_mode == "test":
                            collected.append(package)
                            if len(collected) >= nworkers:
                                collected = sorted(
                                    collected, key=lambda x: x.training_steps
                                )  # sort by num_steps
                                if (
                                        collected[nworkers - 1].training_steps
                                        == collected[0].training_steps
                                ):  # ensure nworkers have provided the same num_steps
                                    tested_steps = collected[0].training_steps
                                    self.process_test_packages(
                                        log_writer=log_writer,
                                        pkgs=collected[:nworkers],
                                        all_results=test_results,
                                    )
                                    collected = collected[nworkers:]
                                    # metrics_file is optional: without it,
                                    # results are only returned to the caller.
                                    if metrics_file is not None:
                                        with open(metrics_file, "w") as f:
                                            json.dump(test_results,
                                                      f,
                                                      indent=4,
                                                      sort_keys=True)
                                        num_done = len(test_results)
                                        # test_steps may be empty (default) or
                                        # shorter than the number of results.
                                        checkpoint_steps = (
                                            test_steps[num_done - 1]
                                            if 0 < num_done <= len(test_steps)
                                            else tested_steps)
                                        get_logger().debug(
                                            "Updated {} up to checkpoint {}".
                                            format(metrics_file,
                                                   checkpoint_steps))
                        else:
                            get_logger().error(
                                f"Runner received unknown package of type {pkg_mode}"
                            )
                    else:
                        pkg_mode = package[0]

                        if pkg_mode == "train_stopped":
                            if package[1] == 0:
                                finalized = True
                                if not self.running_validation:
                                    get_logger().info(
                                        "Terminating runner after trainer done (no validation)"
                                    )
                                    break
                            else:
                                raise Exception(
                                    "Train worker {} abnormally terminated".
                                    format(package[1] - 1))
                        elif pkg_mode == "valid_stopped":
                            raise Exception(
                                "Valid worker {} abnormally terminated".format(
                                    package[1] - 1))
                        elif pkg_mode == "test_stopped":
                            if package[1] == 0:
                                unfinished_workers -= 1
                                if unfinished_workers == 0:
                                    get_logger().info(
                                        "Last tester finished. Terminating")
                                    finalized = True
                                    break
                            else:
                                raise RuntimeError(
                                    "Test worker {} abnormally terminated".
                                    format(package[1] - 1))
                        else:
                            get_logger().error(
                                f"Runner received invalid package tuple {package}"
                            )
                except queue.Empty as _:
                    # Nothing arrived within the timeout: stop once every
                    # worker process has exited.
                    if all(p.exitcode is not None
                           for p in itertools.chain(*self.processes.values())):
                        break
        except KeyboardInterrupt:
            get_logger().info("KeyboardInterrupt. Terminating runner.")
        except Exception:
            get_logger().error("Encountered Exception. Terminating runner.")
            get_logger().exception(traceback.format_exc())
        finally:
            if finalized:
                get_logger().info("Done")
            if log_writer is not None:
                log_writer.close()
            self.close()

        # Returning after (not inside) `finally` avoids silently swallowing
        # exceptions that are not handled above (PEP 765).
        return test_results