コード例 #1
0
ファイル: runner.py プロジェクト: elliotthwang/allenact
    def log(
        self,
        start_time_str: str,
        nworkers: int,
        test_steps: Sequence[int] = (),
        metrics_file: Optional[str] = None,
    ):
        """Consume the ``"results"`` queue and log packages until the run ends.

        The loop polls ``self.queues["results"]`` (1 second timeout) and
        handles two kinds of items:

        * ``LoggingPackage`` instances, dispatched on ``package.mode``:

          - ``"train"`` / ``"test"``: buffered until at least ``nworkers``
            packages are available and the first ``nworkers`` (after sorting)
            report the same step counts; that batch is then handed to
            ``self.process_train_packages`` / ``self.process_test_packages``.
          - ``"valid"``: handed to ``self.process_eval_package`` when
            ``training_steps`` is not ``None``.

        * Tuples whose first element is ``"train_stopped"``,
          ``"valid_stopped"`` or ``"test_stopped"`` and whose second element
          is a status code: ``0`` means a normal stop, any other value ``c``
          is reported as an abnormal termination of worker ``c - 1``.

        Args:
            start_time_str: Forwarded to ``self.log_writer_path`` to build the
                tensorboard log directory.
            nworkers: Number of packages that make up one complete batch; also
                the number of ``("test_stopped", 0)`` messages awaited before
                terminating in test mode.
            test_steps: Optional sequence used only for the "Updated ..." log
                message; element ``i`` is reported for the ``i``-th entry of
                the results list. When it has no matching entry, the
                ``training_steps`` of the processed packages is reported.
            metrics_file: Path of the JSON file rewritten with all test
                results after each processed test batch. When ``None`` no
                file is written.

        Returns:
            The list passed as ``all_results`` to
            ``self.process_test_packages`` (empty if no test batch was
            processed).

        Note:
            Exceptions raised inside the loop (including the ones raised here
            for abnormally terminated workers) are logged and end the loop;
            they are not propagated to the caller.
        """
        finalized = False

        log_writer: Optional[SummaryWriter] = None
        if not self.disable_tensorboard:
            log_writer = SummaryWriter(
                log_dir=self.log_writer_path(start_time_str),
                filename_suffix="__{}_{}".format(
                    self.mode, self.local_start_time_str
                ),
            )

        # To aggregate/buffer metrics from trainers/testers
        collected: List[LoggingPackage] = []
        last_train_steps = 0
        last_offpolicy_steps = 0
        last_train_time = time.time()
        test_results: List[Dict] = []
        unfinished_workers = nworkers

        try:
            while True:
                try:
                    package: Union[
                        LoggingPackage,
                        Union[Tuple[str, Any], Tuple[str, Any, Any]],
                    ] = self.queues["results"].get(timeout=1)

                    if isinstance(package, LoggingPackage):
                        pkg_mode = package.mode

                        if pkg_mode == "train":
                            collected.append(package)
                            if len(collected) >= nworkers:
                                collected = sorted(
                                    collected,
                                    key=lambda pkg: (
                                        pkg.training_steps,
                                        pkg.off_policy_steps,
                                    ),
                                )

                                if (
                                    collected[nworkers - 1].training_steps
                                    == collected[0].training_steps
                                    and collected[nworkers - 1].off_policy_steps
                                    == collected[0].off_policy_steps
                                ):  # ensure nworkers have provided the same num_steps
                                    (
                                        last_train_steps,
                                        last_offpolicy_steps,
                                        last_train_time,
                                    ) = self.process_train_packages(
                                        log_writer=log_writer,
                                        pkgs=collected[:nworkers],
                                        last_steps=last_train_steps,
                                        last_offpolicy_steps=last_offpolicy_steps,
                                        last_time=last_train_time,
                                    )
                                    collected = collected[nworkers:]
                                elif len(collected) > 2 * nworkers:
                                    # The leading space in the second literal
                                    # keeps the two fragments from running
                                    # together ("workersafter").
                                    get_logger().warning(
                                        "Unable to aggregate train packages from all {} workers"
                                        " after {} packages collected".format(
                                            nworkers, len(collected)
                                        )
                                    )
                        elif pkg_mode == "valid":  # they all come from a single worker
                            if package.training_steps is not None:  # no validation samplers
                                self.process_eval_package(
                                    log_writer=log_writer, pkg=package
                                )
                            if (
                                finalized and self.queues["checkpoints"].empty()
                            ):  # assume queue is actually empty after trainer finished and no checkpoints in queue
                                break
                        elif pkg_mode == "test":
                            collected.append(package)
                            if len(collected) >= nworkers:
                                collected = sorted(
                                    collected, key=lambda x: x.training_steps
                                )  # sort by num_steps
                                if (
                                    collected[nworkers - 1].training_steps
                                    == collected[0].training_steps
                                ):  # ensure nworkers have provided the same num_steps
                                    test_pkgs = collected[:nworkers]
                                    self.process_test_packages(
                                        log_writer=log_writer,
                                        pkgs=test_pkgs,
                                        all_results=test_results,
                                    )
                                    collected = collected[nworkers:]

                                    # `metrics_file` is optional: without this
                                    # guard `open(None)` raises TypeError, which
                                    # the blanket handler below would turn into
                                    # a silent early termination of the runner.
                                    if metrics_file is not None:
                                        with open(metrics_file, "w") as f:
                                            json.dump(
                                                test_results,
                                                f,
                                                indent=4,
                                                sort_keys=True,
                                                cls=NumpyJSONEncoder,
                                            )

                                        # `test_steps` may be empty (its
                                        # default) or shorter than the results
                                        # list; fall back to the step count
                                        # shared by the processed packages
                                        # instead of raising IndexError.
                                        # NOTE(review): presumably both values
                                        # identify the same checkpoint - confirm.
                                        result_idx = len(test_results) - 1
                                        if 0 <= result_idx < len(test_steps):
                                            checkpoint_steps = test_steps[result_idx]
                                        else:
                                            checkpoint_steps = test_pkgs[0].training_steps
                                        get_logger().info(
                                            "Updated {} up to checkpoint {}".format(
                                                metrics_file, checkpoint_steps
                                            )
                                        )
                        else:
                            get_logger().error(
                                f"Runner received unknown package of type {pkg_mode}"
                            )
                    else:
                        pkg_mode = package[0]

                        if pkg_mode == "train_stopped":
                            if package[1] == 0:
                                finalized = True
                                if not self.running_validation:
                                    get_logger().info(
                                        "Terminating runner after trainer done (no validation)"
                                    )
                                    break
                            else:
                                raise Exception(
                                    "Train worker {} abnormally terminated".format(
                                        package[1] - 1
                                    )
                                )
                        elif pkg_mode == "valid_stopped":
                            raise Exception(
                                "Valid worker {} abnormally terminated".format(
                                    package[1] - 1
                                )
                            )
                        elif pkg_mode == "test_stopped":
                            if package[1] == 0:
                                unfinished_workers -= 1
                                if unfinished_workers == 0:
                                    get_logger().info(
                                        "Last tester finished. Terminating"
                                    )
                                    finalized = True
                                    break
                            else:
                                raise RuntimeError(
                                    "Test worker {} abnormally terminated".format(
                                        package[1] - 1
                                    )
                                )
                        else:
                            get_logger().error(
                                f"Runner received invalid package tuple {package}"
                            )
                except queue.Empty:
                    # Nothing arrived within the timeout: stop once every
                    # registered process has an exit code, otherwise keep
                    # polling.
                    if all(
                        p.exitcode is not None
                        for p in itertools.chain(*self.processes.values())
                    ):
                        break
        except KeyboardInterrupt:
            get_logger().info("KeyboardInterrupt. Terminating runner.")
        except Exception:
            get_logger().error("Encountered Exception. Terminating runner.")
            get_logger().exception(traceback.format_exc())
        finally:
            if finalized:
                get_logger().info("Done")
            if log_writer is not None:
                log_writer.close()
            self.close()
            # NOTE: returning from `finally` discards any exception still in
            # flight at this point; kept so callers always receive the results
            # gathered so far.
            return test_results
コード例 #2
0
ファイル: viz_utils.py プロジェクト: apoorvkh/allenact
    def make_tensorboard_summary(self):
        """Write one tensorboard summary per experiment from the loaded data.

        For every experiment name in
        ``self.experiment_to_train_events_paths_map`` a ``SummaryWriter`` is
        opened under ``self.tensorboard_output_summary_folder/<experiment>``
        and three groups of scalars are written, in this order:

        1. Each label of ``self.test_data[experiment]`` (sorted), paired with
           the train label obtained by replacing ``"valid"``/``"test"`` with
           ``"train"``. Labels without a matching train label are reported
           with ``print`` and skipped.
        2. Each label of ``self.train_data[experiment]`` containing
           ``"valid"`` (sorted), paired with the train label obtained by
           replacing ``"valid"`` with ``"train"``.
        3. Each label of ``self.train_data[experiment]`` containing
           ``"train"`` (sorted).

        Experiments that have no entry in ``self.test_data`` simply get no
        scalars from the first group.

        Raises:
            AssertionError: If a ``"valid"`` label has no matching train label.
        """

        def write_scalars(writer, tag, series):
            """Emit one ``add_scalar`` call per (score, walltime, step) triple."""
            scores, times, steps = series
            for score, t, step in zip(scores, times, steps):
                writer.add_scalar(tag, score, global_step=step, walltime=t)

        # Snapshot of the experiment names, taken before any writer is opened.
        all_experiments = list(self.experiment_to_train_events_paths_map)

        for experiment_name in all_experiments:
            summary_writer = SummaryWriter(
                os.path.join(self.tensorboard_output_summary_folder, experiment_name)
            )
            # try/finally so the writer is closed even if one of the lookups
            # or helper calls below raises part-way through an experiment.
            try:
                experiment_train_data = self.train_data[experiment_name]

                # Look the experiment up by key rather than only checking that
                # `self.test_data` is non-empty: an experiment without test
                # data must not raise KeyError just because another
                # experiment has some.
                experiment_test_data = (
                    self.test_data[experiment_name]
                    if experiment_name in self.test_data
                    else {}
                )

                for test_label in sorted(experiment_test_data):
                    train_label = test_label.replace("valid", "test").replace(
                        "test", "train"
                    )
                    if train_label not in experiment_train_data:
                        print(
                            f"Missing matching 'train' label {train_label} for eval label {test_label}. Skipping"
                        )
                        continue
                    write_scalars(
                        summary_writer,
                        test_label,
                        self._eval_vs_train_time_steps(
                            experiment_test_data[test_label],
                            experiment_train_data[train_label],
                        ),
                    )

                valid_labels = sorted(
                    key for key in experiment_train_data if "valid" in key
                )
                for valid_label in valid_labels:
                    train_label = valid_label.replace("valid", "train")
                    assert (
                        train_label in experiment_train_data
                    ), f"Missing matching 'train' label {train_label} for valid label {valid_label}"
                    write_scalars(
                        summary_writer,
                        valid_label,
                        self._eval_vs_train_time_steps(
                            experiment_train_data[valid_label],
                            experiment_train_data[train_label],
                        ),
                    )

                train_labels = sorted(
                    key for key in experiment_train_data if "train" in key
                )
                for train_label in train_labels:
                    write_scalars(
                        summary_writer,
                        train_label,
                        self._train_vs_time_steps(experiment_train_data[train_label]),
                    )
            finally:
                summary_writer.close()