Example #1
    def setup(self):
        """Prepare the trainer before it can be used to train the models:
            * initialize G and D
            * compute latent space dims and create classifier accordingly
            * creates 3 optimizers
        """
        self.logger.global_step = 0
        start_time = time()
        self.logger.time.start_time = start_time

        self.loaders = get_all_loaders(self.opts)

        self.G = get_gen(self.opts, verbose=self.verbose).to(self.device)
        self.latent_shape = self.compute_latent_shape()
        self.output_size = self.latent_shape[0] * 2 ** self.opts.gen.t.spade_n_up
        self.G.set_translation_decoder(self.latent_shape, self.device)
        self.D = get_dis(self.opts, verbose=self.verbose).to(self.device)
        self.C = get_classifier(self.opts, self.latent_shape, verbose=self.verbose).to(
            self.device
        )
        self.P = {"s": get_mega_model()}  # P => pseudo labeling models

        self.g_opt, self.g_scheduler = get_optimizer(self.G, self.opts.gen.opt)
        self.d_opt, self.d_scheduler = get_optimizer(self.D, self.opts.dis.opt)
        self.c_opt, self.c_scheduler = get_optimizer(self.C, self.opts.classifier.opt)

        self.set_losses()

        if self.verbose > 0:
            for mode, mode_dict in self.loaders.items():
                for domain, domain_loader in mode_dict.items():
                    print(
                        "Loader {} {} : {}".format(
                            mode, domain, len(domain_loader.dataset)
                        )
                    )

        self.is_setup = True
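
The three get_optimizer calls each return an (optimizer, scheduler) pair, one per network. The helper below is a minimal, hypothetical sketch of that pattern in plain PyTorch; the real get_optimizer reads its settings from opts.*.opt, and the Adam/StepLR choice and hyper-parameters here are illustrative assumptions only.

    from torch import nn, optim

    def get_optimizer_sketch(model: nn.Module, lr: float = 1e-4):
        """Return an (optimizer, scheduler) pair like the ones setup() stores."""
        optimizer = optim.Adam(model.parameters(), lr=lr)  # assumed optimizer choice
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)  # assumed schedule
        return optimizer, scheduler

    # Usage mirroring setup(): one (optimizer, scheduler) pair per network
    g = nn.Linear(8, 8)  # stand-in for self.G
    g_opt, g_scheduler = get_optimizer_sketch(g)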
Example #2
    def setup(self):
        """Prepare the trainer before it can be used to train the models:
            * initialize G and D
            * compute latent space dims and create classifier accordingly
            * creates 3 optimizers
        """
        self.logger.global_step = 0
        start_time = time()
        self.logger.time.start_time = start_time

        self.loaders = get_all_loaders(self.opts)

        self.G: OmniGenerator = get_gen(self.opts,
                                        verbose=self.verbose).to(self.device)
        if self.G.encoder is not None:
            self.latent_shape = self.compute_latent_shape()
        self.input_shape = self.compute_input_shape()
        self.painter_z_h = self.input_shape[-2] // (2 ** self.opts.gen.p.spade_n_up)
        self.painter_z_w = self.input_shape[-1] // (2 ** self.opts.gen.p.spade_n_up)
        self.D: OmniDiscriminator = get_dis(
            self.opts, verbose=self.verbose).to(self.device)
        self.C: OmniClassifier = None
        if self.G.encoder is not None and self.opts.train.latent_domain_adaptation:
            self.C = get_classifier(self.opts,
                                    self.latent_shape,
                                    verbose=self.verbose).to(self.device)
        self.print_num_parameters()

        self.g_opt, self.g_scheduler = get_optimizer(self.G, self.opts.gen.opt)

        if get_num_params(self.D) > 0:
            self.d_opt, self.d_scheduler = get_optimizer(
                self.D, self.opts.dis.opt)
        else:
            self.d_opt, self.d_scheduler = None, None

        if self.C is not None:
            self.c_opt, self.c_scheduler = get_optimizer(
                self.C, self.opts.classifier.opt)
        else:
            self.c_opt, self.c_scheduler = None, None

        if self.opts.train.resume:
            self.resume()

        self.losses = get_losses(self.opts, self.verbose, device=self.device)

        if self.verbose > 0:
            for mode, mode_dict in self.loaders.items():
                for domain, domain_loader in mode_dict.items():
                    print("Loader {} {} : {}".format(
                        mode, domain, len(domain_loader.dataset)))

        # Create display images:
        print("Creating display images...", end="", flush=True)

        if isinstance(self.opts.comet.display_size, int):
            display_indices = range(self.opts.comet.display_size)
        else:
            display_indices = self.opts.comet.display_size

        self.display_images = {}
        for mode, mode_dict in self.loaders.items():
            self.display_images[mode] = {}
            for domain, domain_loader in mode_dict.items():

                self.display_images[mode][domain] = [
                    Dict(self.loaders[mode][domain].dataset[i])
                    for i in display_indices
                    if i < len(self.loaders[mode][domain].dataset)
                ]

        self.is_setup = True
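
In the snippet above, painter_z_h and painter_z_w are just the input height and width divided by 2**spade_n_up, i.e. one halving per SPADE upsampling stage. A quick standalone check with made-up numbers (the 640x640 input and spade_n_up = 7 are illustrative, not values taken from the repository):

    spade_n_up = 7                 # assumed number of SPADE upsampling stages
    input_shape = (3, 640, 640)    # assumed (C, H, W) input shape

    painter_z_h = input_shape[-2] // (2 ** spade_n_up)  # 640 // 128 = 5
    painter_z_w = input_shape[-1] // (2 ** spade_n_up)  # 640 // 128 = 5
    print(painter_z_h, painter_z_w)  # 5 5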
Example #3
    # ------------------------
    # -----  Test Setup  -----
    # ------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    target_domains = ["r", "s"]
    # integer class indices for the cross-entropy, one-hot vectors for the L1 loss
    labels = domains_to_class_tensor(target_domains, one_hot=False).to(device)
    one_hot_labels = domains_to_class_tensor(target_domains, one_hot=True).to(device)
    cross_entropy = CrossEntropy()
    loss_l1 = L1Loss()

    # ------------------------------
    # -----  Test C.forward()  -----
    # ------------------------------
    # dummy latent batch: one (128, 32, 32) feature map per target domain
    z = torch.ones(len(target_domains), 128, 32, 32).to(device)
    latent_space = (128, 32, 32)
    C = get_classifier(opts, latent_space, 0).to(device)
    y = C(z)  # classifier logits, one row per sample
    tprint(
        "output of classifier's shape for latent space {} :".format(
            list(z.shape[1:])),
        y.shape,
    )
    # --------------------------------
    # -----  Test cross_entropy  -----
    # --------------------------------

    tprint("CE loss:", cross_entropy(y, labels))
    # --------------------------
    # -----  Test l1_loss  -----
    # --------------------------
    tprint("l1 loss:", loss_l1(y, one_hot_labels))