Example #1
    def print_num_parameters(self):
        print("---------------------------")
        if self.G.encoder is not None:
            print("num params encoder: ", get_num_params(self.G.encoder))
        for d in self.G.decoders.keys():
            print("num params decoder {}: {}".format(
                d, get_num_params(self.G.decoders[d])))
        for d in self.D.keys():
            print("num params discrim {}: {}".format(
                d, get_num_params(self.D[d])))
        print("num params painter: ", get_num_params(self.G.painter))
        if self.C is not None:
            print("num params classif: ", get_num_params(self.C))
        print("---------------------------")
Example #2
    def resume(self):
        # load_path = self.get_latest_ckpt()
        if "m" in self.opts.tasks and "p" in self.opts.tasks:
            m_path = self.opts.load_paths.m
            p_path = self.opts.load_paths.p

            if m_path == "none":
                m_path = self.opts.output_path
            if p_path == "none":
                p_path = self.opts.output_path

            # Merge the dicts
            m_ckpt_path = Path(m_path) / Path("checkpoints/latest_ckpt.pth")
            p_ckpt_path = Path(p_path) / Path("checkpoints/latest_ckpt.pth")

            m_checkpoint = torch.load(m_ckpt_path)
            p_checkpoint = torch.load(p_ckpt_path)

            checkpoint = merge(m_checkpoint, p_checkpoint)
            print(f"Resuming model from {m_ckpt_path} and {p_ckpt_path}")
        else:
            load_path = Path(
                self.opts.output_path) / Path("checkpoints/latest_ckpt.pth")
            checkpoint = torch.load(load_path)
            print(f"Resuming model from {load_path}")

        self.G.load_state_dict(checkpoint["G"])
        if not ("m" in self.opts.tasks and "p" in self.opts.tasks):
            self.g_opt.load_state_dict(checkpoint["g_opt"])
        self.logger.epoch = checkpoint["epoch"]
        self.logger.global_step = checkpoint["step"]
        # Round step to even number for extraGradient
        if self.logger.global_step % 2 != 0:
            self.logger.global_step += 1

        if self.C is not None and get_num_params(self.C) > 0:
            self.C.load_state_dict(checkpoint["C"])
            if not ("m" in self.opts.tasks and "p" in self.opts.tasks):
                self.c_opt.load_state_dict(checkpoint["c_opt"])

        if self.D is not None and get_num_params(self.D) > 0:
            self.D.load_state_dict(checkpoint["D"])
            if not ("m" in self.opts.tasks and "p" in self.opts.tasks):
                self.d_opt.load_state_dict(checkpoint["d_opt"])
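
The merge function called above is not defined in this snippet; presumably it deep-merges the masker ("m") and painter ("p") checkpoint dictionaries into one dict before loading. A minimal sketch of such a recursive merge, assuming plain nested dicts and letting the second checkpoint win on key conflicts (the project may instead rely on a utility library for this):

def merge(base, other):
    # Recursively merge `other` into a copy of `base`.
    # Values from `other` take precedence, except when both values
    # are dicts, in which case they are merged recursively.
    out = dict(base)
    for key, value in other.items():
        if key in out and isinstance(out[key], dict) and isinstance(value, dict):
            out[key] = merge(out[key], value)
        else:
            out[key] = value
    return out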
Example #3
    def resume(self):
        # load_path = self.get_latest_ckpt()
        load_path = Path(
            self.opts.output_path) / Path("checkpoints/latest_ckpt.pth")
        checkpoint = torch.load(load_path)
        print(f"Resuming model from {load_path}")
        self.G.load_state_dict(checkpoint["G"])
        self.g_opt.load_state_dict(checkpoint["g_opt"])
        self.logger.epoch = checkpoint["epoch"]
        self.logger.global_step = checkpoint["step"]
        # Round step to even number for extraGradient
        if self.logger.global_step % 2 != 0:
            self.logger.global_step += 1

        if self.C is not None and get_num_params(self.C) > 0:
            self.C.load_state_dict(checkpoint["C"])
            self.c_opt.load_state_dict(checkpoint["c_opt"])

        if self.D is not None and get_num_params(self.D) > 0:
            self.D.load_state_dict(checkpoint["D"])
            self.d_opt.load_state_dict(checkpoint["d_opt"])
Example #4
    def save(self):
        save_dir = Path(self.opts.output_path) / Path("checkpoints")
        save_dir.mkdir(exist_ok=True)
        save_path = Path("latest_ckpt.pth")
        save_path = save_dir / save_path

        # Construct relevant state dicts / optims:
        # Save at least G
        save_dict = {
            "epoch": self.logger.epoch,
            "G": self.G.state_dict(),
            "g_opt": self.g_opt.state_dict(),
            "step": self.logger.global_step,
        }

        if self.C is not None and get_num_params(self.C) > 0:
            save_dict["C"] = self.C.state_dict()
            save_dict["c_opt"] = self.c_opt.state_dict()
        if self.D is not None and get_num_params(self.D) > 0:
            save_dict["D"] = self.D.state_dict()
            save_dict["d_opt"] = self.d_opt.state_dict()

        torch.save(save_dict, save_path)
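
For reference, the file written by save() can be inspected directly with torch.load; a short sketch, where the path below is illustrative (the real location is opts.output_path / "checkpoints/latest_ckpt.pth"):

import torch

ckpt = torch.load("output/checkpoints/latest_ckpt.pth", map_location="cpu")
# Keys depend on which components were active at save time:
# always "epoch", "G", "g_opt", "step"; optionally "C"/"c_opt" and "D"/"d_opt"
print(sorted(ckpt.keys()))
print("saved at epoch", ckpt["epoch"], "step", ckpt["step"])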
Example #5
    def setup(self):
        """Prepare the trainer before it can be used to train the models:
            * initialize G and D
            * compute latent space dims and create classifier accordingly
            * creates 3 optimizers
        """
        self.logger.global_step = 0
        start_time = time()
        self.logger.time.start_time = start_time

        self.loaders = get_all_loaders(self.opts)

        self.G: OmniGenerator = get_gen(self.opts,
                                        verbose=self.verbose).to(self.device)
        if self.G.encoder is not None:
            self.latent_shape = self.compute_latent_shape()
        self.input_shape = self.compute_input_shape()
        self.painter_z_h = self.input_shape[-2] // (2 ** self.opts.gen.p.spade_n_up)
        self.painter_z_w = self.input_shape[-1] // (2 ** self.opts.gen.p.spade_n_up)
        self.D: OmniDiscriminator = get_dis(
            self.opts, verbose=self.verbose).to(self.device)
        self.C: OmniClassifier = None
        if self.G.encoder is not None and self.opts.train.latent_domain_adaptation:
            self.C = get_classifier(self.opts,
                                    self.latent_shape,
                                    verbose=self.verbose).to(self.device)
        self.print_num_parameters()

        self.g_opt, self.g_scheduler = get_optimizer(self.G, self.opts.gen.opt)

        if get_num_params(self.D) > 0:
            self.d_opt, self.d_scheduler = get_optimizer(
                self.D, self.opts.dis.opt)
        else:
            self.d_opt, self.d_scheduler = None, None

        if self.C is not None:
            self.c_opt, self.c_scheduler = get_optimizer(
                self.C, self.opts.classifier.opt)
        else:
            self.c_opt, self.c_scheduler = None, None

        if self.opts.train.resume:
            self.resume()

        self.losses = get_losses(self.opts, self.verbose, device=self.device)

        if self.verbose > 0:
            for mode, mode_dict in self.loaders.items():
                for domain, domain_loader in mode_dict.items():
                    print("Loader {} {} : {}".format(
                        mode, domain, len(domain_loader.dataset)))

        # Create display images:
        print("Creating display images...", end="", flush=True)

        if type(self.opts.comet.display_size) == int:
            display_indices = range(self.opts.comet.display_size)
        else:
            display_indices = self.opts.comet.display_size

        self.display_images = {}
        for mode, mode_dict in self.loaders.items():
            self.display_images[mode] = {}
            for domain, domain_loader in mode_dict.items():

                self.display_images[mode][domain] = [
                    Dict(self.loaders[mode][domain].dataset[i])
                    for i in display_indices
                    if i < len(self.loaders[mode][domain].dataset)
                ]

        self.is_setup = True
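
The painter's spatial latent size computed in setup() is simply the input resolution divided by the total SPADE upsampling factor. A worked example with hypothetical numbers (the actual input shape and spade_n_up come from the options):

# Hypothetical values: 640x640 input, 7 SPADE upsampling stages
input_shape = (3, 640, 640)
spade_n_up = 7

painter_z_h = input_shape[-2] // (2 ** spade_n_up)  # 640 // 128 = 5
painter_z_w = input_shape[-1] // (2 ** spade_n_up)  # 640 // 128 = 5
print(painter_z_h, painter_z_w)  # -> 5 5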
Example #6
    G = get_gen(opts).to(device)

    # -------------------------------
    # -----  Test Architecture  -----
    # -------------------------------
    if print_architecture:
        print(G)
        # print("DECODERS:", G.decoders)
        # print("ENCODER:", G.encoder)

    # ------------------------------------
    # -----  Test encoder.forward()  -----
    # ------------------------------------
    if test_encoder:
        print_header("test_encoder")
        num_params = get_num_params(G.encoder)
        print("Number of parameters in encoder : {}".format(num_params))
        encoded = G.encode(image)
        print("Latent space dims {}".format(tuple(encoded.shape)[1:]))

    # -------------------------------------------------------
    # -----  Test encode then decode with all decoders  -----
    # -------------------------------------------------------
    if test_encode_decode:
        print_header("test_encode_decode")
        z = G.encode(image)
        for dec in "adhtw":
            if dec in G.decoders:
                if dec == "t":
                    continue
                if dec == "a":