Example #1
    def log_losses(self, model_to_update="G"):
        """Logs metrics on comet.ml

        Args:
            model_to_update (str, optional): One of "G", "D" or "C". Defaults to "G".
        """
        if self.opts.train.log_level < 1:
            return

        if self.exp is None:
            return

        assert model_to_update in {
            "G",
            "D",
            "C",
        }, "unknown model to log losses {}".format(model_to_update)

        losses = self.logger.losses.copy()
        if self.opts.train.log_level == 1:
            # Only log aggregated losses: delete other keys in losses
            for k in self.logger.losses:
                if k not in {"representation", "generator", "translation"}:
                    del losses[k]
        # convert losses into a single-level dictionary
        losses = flatten_opts(losses)
        self.exp.log_metrics(
            losses, prefix=model_to_update, step=self.logger.global_step
        )
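
This example hands the result of `flatten_opts` to `Experiment.log_metrics`, which expects a single-level mapping. Based on the behavior asserted in Example #4 below (nested keys joined with dots), a minimal sketch of such a helper could look like this; the repository's actual implementation may handle more cases:

def flatten_opts(opts, parent_key=""):
    """Recursively flatten a nested mapping into a single-level dict
    with dot-joined keys, e.g. {"a": {"b": 1}} -> {"a.b": 1}."""
    flat = {}
    for key, value in opts.items():
        new_key = f"{parent_key}.{key}" if parent_key else str(key)
        if isinstance(value, dict):
            # Recurse into nested mappings, carrying the joined key prefix
            flat.update(flatten_opts(value, new_key))
        else:
            flat[new_key] = value
    return flat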
Example #2
    def log_losses(self, model_to_update="G", mode="train"):
        """Logs metrics on comet.ml

        Args:
            model_to_update (str, optional): One of "G", "D" or "C". Defaults to "G".
        """
        loss_names = {
            "G": "generator",
            "D": "discriminator",
            "C": "classifier"
        }

        if self.opts.train.log_level < 1:
            return

        if self.exp is None:
            return

        assert model_to_update in {
            "G",
            "D",
            "C",
        }, "unknown model to log losses {}".format(model_to_update)

        loss_to_update = self.logger.losses[loss_names[model_to_update]]

        losses = loss_to_update.copy()

        if self.opts.train.log_level == 1:
            # Only log aggregated losses: delete other keys in losses
            for k in loss_to_update:
                if k not in {"masker", "total_loss", "painter"}:
                    del losses[k]
        # convert losses into a single-level dictionary
        losses = flatten_opts(losses)
        self.exp.log_metrics(losses,
                             prefix=f"{model_to_update}_{mode}",
                             step=self.logger.global_step)
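
Since `mode` only changes the metric prefix (`f"{model_to_update}_{mode}"`), the same method covers both training and validation logging. A hypothetical call site (the `trainer` name is illustrative):

# Metrics land under the "G_train" and "D_val" prefixes respectively
trainer.log_losses(model_to_update="G", mode="train")
trainer.log_losses(model_to_update="D", mode="val")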
Example #3
def main(opts):
    # -----------------------------
    # -----  Parse arguments  -----
    # -----------------------------

    opts = Dict(OmegaConf.to_container(opts))
    args = opts.args

    # -----------------------
    # -----  Load opts  -----
    # -----------------------

    opts = load_opts(args.config, default=opts)
    if args.resume:
        opts.train.resume = True
    opts.output_path = env_to_path(opts.output_path)

    if not opts.train.resume:
        opts.output_path = get_increased_path(opts.output_path)
    pprint("Running model in", opts.output_path)

    exp = None
    if not args.dev:
        # -------------------------------
        # -----  Check output_path  -----
        # -------------------------------
        if opts.train.resume:
            Path(opts.output_path).mkdir(exist_ok=True)
        else:
            assert not Path(opts.output_path).exists()
            Path(opts.output_path).mkdir()

        # Save config file
        # TODO what if resuming? re-dump?
        with (Path(opts.output_path) / "opts.yaml").open("w") as f:
            yaml.safe_dump(opts.to_dict(), f)

        if not args.no_comet:
            # ----------------------------------
            # -----  Set Comet Experiment  -----
            # ----------------------------------
            exp = Experiment(project_name="omnigan", auto_metric_logging=False)
            exp.log_parameters(flatten_opts(opts))
            if args.note:
                exp.log_parameter("note", args.note)
            with open(Path(opts.output_path) / "comet_url.txt", "w") as f:
                f.write(exp.url)
    else:
        # ----------------------
        # -----  Dev Mode  -----
        # ----------------------
        pprint("> /!\\ Development mode ON")
        print("Cropping data to 32")
        opts.data.transforms += [
            Dict({
                "name": "crop",
                "ignore": False,
                "height": 32,
                "width": 32
            })
        ]

    # -------------------
    # -----  Train  -----
    # -------------------
    trainer = Trainer(opts, comet_exp=exp)
    trainer.logger.time.start_time = time()
    trainer.setup()
    trainer.train()

    # -----------------------------
    # -----  End of training  -----
    # -----------------------------

    pprint("Done training")
Example #4
    print_header("test_get_increased_path")
    uid = str(uuid.uuid4())
    p = Path() / uid
    p.mkdir()
    get_increased_path(p).mkdir()
    get_increased_path(p).mkdir()
    get_increased_path(p).mkdir()
    paths = {str(d) for d in Path().glob(uid + "*")}
    target = {str(p), str(p) + " (1)", str(p) + " (2)", str(p) + " (3)"}
    assert paths == target
    print("ok.")
    for d in Path().glob(uid + "*"):
        d.rmdir()

    # ----------------------------------
    # -----  Testing flatten_opts  -----
    # ----------------------------------
    print_header("test_flatten_opts")
    d = addict.Dict()
    d.a.b.c = 2
    d.a.b.d = 3
    d.a.e = 4
    d.f = 5
    assert flatten_opts(d) == {
        "a.b.c": 2,
        "a.b.d": 3,
        "a.e": 4,
        "f": 5,
    }
    print("ok.")