Code example #1
    def setup(self):
        """Prepare the trainer before it can be used to train the models:
            * initialize G and D
            * compute latent space dims and create classifier accordingly
            * creates 3 optimizers
        """
        self.logger.global_step = 0
        start_time = time()
        self.logger.time.start_time = start_time

        self.loaders = get_all_loaders(self.opts)

        self.G = get_gen(self.opts, verbose=self.verbose).to(self.device)
        self.latent_shape = self.compute_latent_shape()
        self.output_size = self.latent_shape[0] * 2 ** self.opts.gen.t.spade_n_up
        self.G.set_translation_decoder(self.latent_shape, self.device)
        self.D = get_dis(self.opts, verbose=self.verbose).to(self.device)
        self.C = get_classifier(self.opts, self.latent_shape, verbose=self.verbose).to(
            self.device
        )
        self.P = {"s": get_mega_model()}  # P => pseudo labeling models

        self.g_opt, self.g_scheduler = get_optimizer(self.G, self.opts.gen.opt)
        self.d_opt, self.d_scheduler = get_optimizer(self.D, self.opts.dis.opt)
        self.c_opt, self.c_scheduler = get_optimizer(self.C, self.opts.classifier.opt)

        self.set_losses()

        if self.verbose > 0:
            for mode, mode_dict in self.loaders.items():
                for domain, domain_loader in mode_dict.items():
                    print(
                        "Loader {} {} : {}".format(
                            mode, domain, len(domain_loader.dataset)
                        )
                    )

        self.is_setup = True
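
For context, this setup() method is meant to be called exactly once on a trainer instance before any training step. Below is a minimal usage sketch; the Trainer import path, its constructor signature, the load_opts helper, and the train() call are assumptions for illustration and are not shown in the excerpt above.

# Hypothetical usage sketch: import paths, Trainer(opts, ...) signature,
# load_opts and train() are assumptions, not part of the excerpt above.
from omnigan.trainer import Trainer   # assumed module path
from omnigan.utils import load_opts   # assumed config helper

opts = load_opts("config/defaults.yaml")   # hypothetical config file
trainer = Trainer(opts, verbose=1)
trainer.setup()          # builds G, D, C, the loaders, and the optimizers
assert trainer.is_setup  # flag set at the end of setup()
trainer.train()          # training may only start once setup() has run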
Code example #2
File: trainer.py Project: fagan2888/omnigan
    def setup(self):
        """Prepare the trainer before it can be used to train the models:
            * initialize G and D
            * compute latent space dims and create classifier accordingly
            * creates 3 optimizers
        """
        self.logger.global_step = 0
        start_time = time()
        self.logger.time.start_time = start_time

        self.loaders = get_all_loaders(self.opts)

        self.G: OmniGenerator = get_gen(self.opts,
                                        verbose=self.verbose).to(self.device)
        if self.G.encoder is not None:
            self.latent_shape = self.compute_latent_shape()
        self.input_shape = self.compute_input_shape()
        self.painter_z_h = self.input_shape[-2] // (2 ** self.opts.gen.p.spade_n_up)
        self.painter_z_w = self.input_shape[-1] // (2 ** self.opts.gen.p.spade_n_up)
        self.D: OmniDiscriminator = get_dis(
            self.opts, verbose=self.verbose).to(self.device)
        self.C: OmniClassifier = None
        if self.G.encoder is not None and self.opts.train.latent_domain_adaptation:
            self.C = get_classifier(self.opts,
                                    self.latent_shape,
                                    verbose=self.verbose).to(self.device)
        self.print_num_parameters()

        self.g_opt, self.g_scheduler = get_optimizer(self.G, self.opts.gen.opt)

        if get_num_params(self.D) > 0:
            self.d_opt, self.d_scheduler = get_optimizer(
                self.D, self.opts.dis.opt)
        else:
            self.d_opt, self.d_scheduler = None, None

        if self.C is not None:
            self.c_opt, self.c_scheduler = get_optimizer(
                self.C, self.opts.classifier.opt)
        else:
            self.c_opt, self.c_scheduler = None, None

        if self.opts.train.resume:
            self.resume()

        self.losses = get_losses(self.opts, self.verbose, device=self.device)

        if self.verbose > 0:
            for mode, mode_dict in self.loaders.items():
                for domain, domain_loader in mode_dict.items():
                    print("Loader {} {} : {}".format(
                        mode, domain, len(domain_loader.dataset)))

        # Create display images:
        print("Creating display images...", end="", flush=True)

        if isinstance(self.opts.comet.display_size, int):
            display_indices = range(self.opts.comet.display_size)
        else:
            display_indices = self.opts.comet.display_size

        self.display_images = {}
        for mode, mode_dict in self.loaders.items():
            self.display_images[mode] = {}
            for domain, domain_loader in mode_dict.items():

                self.display_images[mode][domain] = [
                    Dict(self.loaders[mode][domain].dataset[i])
                    for i in display_indices
                    if i < len(self.loaders[mode][domain].dataset)
                ]

        self.is_setup = True
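
The display-image block at the end accepts opts.comet.display_size either as an int (take the first N samples of each loader) or as an explicit list of indices. Here is a small self-contained sketch of that branch, using a toy list in place of a real domain_loader.dataset; the names below are illustrative only.

# Standalone illustration of the display_size handling above; `display_size`
# mirrors opts.comet.display_size and the toy dataset is purely illustrative.
display_size = 4              # could also be an explicit list, e.g. [0, 3, 7]
dataset = list(range(10))     # stands in for loaders[mode][domain].dataset

if isinstance(display_size, int):
    display_indices = range(display_size)
else:
    display_indices = display_size

display_images = [dataset[i] for i in display_indices if i < len(dataset)]
print(display_images)         # -> [0, 1, 2, 3]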
Code example #3
File: test_losses.py Project: tianyu-z/omnigan
# Standard-library and torch imports needed by this script. The project-level
# helpers (get_all_loaders, get_gen, PixelCrossEntropy, load_test_opts,
# print_header) are imported from the omnigan package and its test utilities;
# their exact module paths are not shown in this excerpt.
import argparse
from pathlib import Path

import torch

parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config", default="config/local_tests.yaml")
args = parser.parse_args()
root = Path(__file__).parent.parent
opts = load_test_opts(args.config)

if __name__ == "__main__":
    # ------------------------
    # -----  Test Setup  -----
    # ------------------------
    opts.data.loaders.batch_size = 2
    opts.data.loaders.num_workers = 2
    opts.data.loaders.shuffle = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loaders = get_all_loaders(opts)
    batch = next(iter(loaders["train"]["rn"]))
    image = torch.randn(opts.data.loaders.batch_size, 3, 32, 32).to(device)
    G = get_gen(opts).to(device)
    z = G.encode(image)

    # -----------------------------------
    # -----  Test cross_entropy_2d  -----
    # -----------------------------------
    print_header("test_cross_entropy_2d")
    prediction = G.decoders["s"](z)
    pce = PixelCrossEntropy()
    print(pce(prediction, batch["data"]["s"].to(device)))
    # ! unresolved: how to infer from cropped data when the input is 224 but the output is 256?

    # TODO: add more tests for the losses
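
PixelCrossEntropy is used above as a per-pixel cross-entropy between the segmentation decoder's logits and the label map. Below is a minimal stand-in sketch, assuming (N, C, H, W) logits and (N, H, W) integer targets; this is an assumption about the interface, not the project's actual implementation.

import torch
import torch.nn as nn


class PixelCrossEntropySketch(nn.Module):
    """Per-pixel cross-entropy: a hypothetical stand-in for PixelCrossEntropy.

    Assumes logits of shape (N, C, H, W) and integer targets of shape (N, H, W);
    the project's real class may differ (class weights, ignore_index, etc.).
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()  # averages over every pixel by default

    def forward(self, logits, target):
        return self.loss(logits, target.long())


if __name__ == "__main__":
    logits = torch.randn(2, 11, 32, 32)         # 11 classes, arbitrary choice
    target = torch.randint(0, 11, (2, 32, 32))  # one class index per pixel
    print(PixelCrossEntropySketch()(logits, target))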