Ejemplo n.º 1
0
                    v = int(v)
            if k in config:
                print("Overwriting {:20} {:30} -> {:}".format(k, config[k], v))
                config[k] = v
        except Exception as e:
            print(e)
            print("Ignoring argument", u)

    # Let every public attribute of ``opts`` that names an existing config key
    # override that config value.
    for o in dir(opts):
        if not o.startswith("_"):
            if o in config:
                new_value = getattr(opts, o)
                # Report the value being replaced for *this* key (``o``), not a
                # leftover loop variable. ``!s`` lets non-formattable values
                # such as None or lists be printed with a width spec.
                print("Overwriting {:20} {!s:30} -> {:}".format(
                    o, config[o], new_value))
                config[o] = new_value

# Upload the config file given on the command line as a Comet asset.
comet_exp.log_asset(opts.config)
max_iter = config["max_iter"]
display_size = config["display_size"]
# NOTE(review): the VGG model path is pointed at the output path here —
# confirm this is intended (e.g. VGG weights are stored under output_path).
config["vgg_model_path"] = opts.output_path

# Log the config (after the overrides above) as experiment parameters; keep
# this after the vgg_model_path assignment so that value is included.
comet_exp.log_parameters(config)

print("Using model", opts.trainer)
# Setup model and data loader
if opts.trainer == "MUNIT":
    trainer = MUNIT_Trainer(config, comet_exp)
elif opts.trainer == "UNIT":
    trainer = UNIT_Trainer(config)
elif opts.trainer == "DoubleMUNIT":
    trainer = DoubleMUNIT_Trainer(config, comet_exp)
else:
Ejemplo n.º 2
0
class CometLogger(Logger):
    """Logger that forwards metrics, parameters and snapshots to a Comet experiment.

    The backing experiment is, depending on the constructor arguments, a
    ``PrintExperiment``, a resumed ``ExistingExperiment``, a new ``Experiment``,
    or — when ``Experiment`` raises ``ValueError`` (commented below as a missing
    API key) — an ``OfflineExperiment`` writing to ``~/logs``.
    """

    def __init__(
        self,
        batch_size: int,
        snapshot_dir: Optional[str] = None,
        snapshot_mode: str = "last",
        snapshot_gap: int = 1,
        exp_set: Optional[str] = None,
        use_print_exp: bool = False,
        saved_exp: Optional[str] = None,
        **kwargs,
    ):
        """
        :param batch_size: multiplier applied to ``step`` in :meth:`log` before
            metrics are sent; also forwarded to the base ``Logger``.
        :param snapshot_dir: base snapshot directory; when given, the Comet
            experiment key is appended as a subdirectory.
        :param snapshot_mode: forwarded to the base ``Logger``.
        :param snapshot_gap: forwarded to the base ``Logger``.
        :param exp_set: if given, logged as the ``exp_set`` parameter.
        :param use_print_exp: use a ``PrintExperiment`` instead of Comet.
        :param saved_exp: key of an existing Comet experiment to resume.
        :param kwargs: passed to comet's Experiment (or ExistingExperiment) at
            init; not passed to the offline fallback.
        """
        if use_print_exp:
            self.experiment = PrintExperiment()
        else:
            from comet_ml import Experiment, ExistingExperiment, OfflineExperiment

            if saved_exp:
                self.experiment = ExistingExperiment(
                    previous_experiment=saved_exp, **kwargs
                )
            else:
                try:
                    self.experiment = Experiment(**kwargs)
                except ValueError:  # no API key
                    log_dir = Path.home() / "logs"
                    log_dir.mkdir(exist_ok=True)
                    self.experiment = OfflineExperiment(offline_directory=str(log_dir))

        # Flipped to True in shutdown() only on a clean (non-error) exit.
        self.experiment.log_parameter("complete", False)
        if exp_set:
            self.experiment.log_parameter("exp_set", exp_set)
        if snapshot_dir:
            snapshot_dir = Path(snapshot_dir) / self.experiment.get_key()
        # log_traj_window (int): How many trajectories to hold in deque for computing performance statistics.
        self.log_traj_window = 100
        # Running totals logged on every call to log().
        self._cum_metrics = {
            "n_unsafe_actions": 0,
            "constraint_used": 0,
            "cum_completed_trajs": 0,
            "logging_time": 0,
        }
        self._new_completed_trajs = 0
        self._last_step = 0
        self._start_time = self._last_time = time()
        self._last_snapshot_upload = 0
        # Minimum number of seconds between snapshot uploads (30 minutes).
        self._snaphot_upload_time = 30 * 60

        super().__init__(batch_size, snapshot_dir, snapshot_mode, snapshot_gap)

    def log_fast(
        self,
        step: int,
        traj_infos: Sequence[Dict[str, float]],
        opt_info: Optional[Tuple[Sequence[float], ...]] = None,
        test: bool = False,
    ) -> None:
        """Update the cumulative counters from ``traj_infos`` without sending anything.

        ``step``, ``opt_info`` and ``test`` are accepted for signature parity
        with :meth:`log` and are not used here.
        """
        if not traj_infos:
            return
        start = time()

        n_trajs = len(traj_infos)
        self._new_completed_trajs += n_trajs
        self._cum_metrics["cum_completed_trajs"] += n_trajs
        # TODO: do we need to support sum(t[k]) if key in k?
        # without that, this doesn't include anything from extra eval samplers
        for key in self._cum_metrics:
            # These counters are maintained by the logger itself, not summed
            # from per-trajectory info.
            if key in ("cum_completed_trajs", "logging_time"):
                continue
            self._cum_metrics[key] += sum(t.get(key, 0) for t in traj_infos)
        self._cum_metrics["logging_time"] += time() - start

    def log(
        self,
        step: int,
        traj_infos: Sequence[Dict[str, float]],
        opt_info: Optional[Tuple[Sequence[float], ...]] = None,
        test: bool = False,
    ):
        """Accumulate counters and send optimizer, trajectory and rate metrics to Comet.

        ``step`` is multiplied by ``batch_size`` before use. When ``test`` is
        true, metrics are logged inside the experiment's test context and the
        throughput bookkeeping is left untouched.
        """
        self.log_fast(step, traj_infos, opt_info, test)
        start = time()
        with (self.experiment.test() if test else nullcontext()):
            step *= self.batch_size
            if opt_info is not None:
                # grad norm is left on the GPU for some reason
                # https://github.com/astooke/rlpyt/issues/163
                self.experiment.log_metrics(
                    {
                        k: np.mean(v)
                        for k, v in opt_info._asdict().items()
                        if k != "gradNorm"
                    },
                    step=step,
                )

            if traj_infos:
                # Mean of every per-trajectory key except the cumulative ones,
                # which are logged separately below.
                agg_vals = {}
                for key in traj_infos[0].keys():
                    if key in self._cum_metrics:
                        continue
                    agg_vals[key] = sum(t[key] for t in traj_infos) / len(traj_infos)
                self.experiment.log_metrics(agg_vals, step=step)

            if not test:
                now = time()
                elapsed = now - self._last_time
                metrics = {"new_completed_trajs": self._new_completed_trajs}
                # Two calls close together can see the same time() value on a
                # coarse clock; skip the rate instead of dividing by zero.
                if elapsed > 0:
                    metrics["steps_per_second"] = (step - self._last_step) / elapsed
                self.experiment.log_metrics(metrics, step=step)
                self._last_time = now
                self._last_step = step
                self._new_completed_trajs = 0

        self.experiment.log_metrics(self._cum_metrics, step=step)
        self._cum_metrics["logging_time"] += time() - start

    def log_metric(self, name, val):
        """Log a single metric (no step) to the experiment."""
        self.experiment.log_metric(name, val)

    def log_parameters(self, parameters):
        """Log a dict of parameters to the experiment."""
        self.experiment.log_parameters(parameters)

    def log_config(self, config):
        """Log ``config`` as one JSON-encoded ``config`` parameter."""
        self.experiment.log_parameter("config", json.dumps(convert_dict(config)))

    def upload_snapshot(self):
        """Upload the most recent snapshot file as an asset, if snapshots are enabled."""
        if not self.snapshot_dir:
            return
        # NOTE(review): ``_previous_snapshot_fname`` is presumably set by the
        # base Logger when a snapshot is saved; it may not exist (or be empty)
        # if nothing has been saved yet, e.g. on an early shutdown.
        fname = getattr(self, "_previous_snapshot_fname", None)
        if fname:
            self.experiment.log_asset(fname)

    def save_itr_params(
        self, step: int, params: Dict[str, Any], metric: Optional[float] = None
    ) -> None:
        """Save parameters via the base Logger, uploading at most once per upload interval."""
        super().save_itr_params(step, params, metric)
        now = time()
        if now - self._last_snapshot_upload > self._snaphot_upload_time:
            self._last_snapshot_upload = now
            self.upload_snapshot()

    def shutdown(self, error: bool = False) -> None:
        """Finish the experiment.

        On a clean exit (``error`` false) the final snapshot is uploaded and
        ``complete`` is set to True. The experiment is always ended, even if
        the final upload raises.
        """
        try:
            if not error:
                self.upload_snapshot()
                self.experiment.log_parameter("complete", True)
        finally:
            self.experiment.end()