def setup(self):
    """Prepare the trainer before it can be used to train the models:
    * initialize G and D
    * compute latent space dims and create classifier accordingly
    * creates 3 optimizers
    """
    # Reset step counter and record wall-clock start for the logger.
    self.logger.global_step = 0
    self.logger.time.start_time = time()

    self.loaders = get_all_loaders(self.opts)

    # Generator comes first: the latent shape is derived from it and is
    # needed by both its own translation decoder and the classifier.
    self.G = get_gen(self.opts, verbose=self.verbose).to(self.device)
    self.latent_shape = self.compute_latent_shape()
    self.output_size = self.latent_shape[0] * 2 ** self.opts.gen.t.spade_n_up
    self.G.set_translation_decoder(self.latent_shape, self.device)

    self.D = get_dis(self.opts, verbose=self.verbose).to(self.device)
    self.C = get_classifier(self.opts, self.latent_shape, verbose=self.verbose).to(
        self.device
    )
    # P => pseudo labeling models
    self.P = {"s": get_mega_model()}

    # One optimizer (+ scheduler) per trainable model.
    self.g_opt, self.g_scheduler = get_optimizer(self.G, self.opts.gen.opt)
    self.d_opt, self.d_scheduler = get_optimizer(self.D, self.opts.dis.opt)
    self.c_opt, self.c_scheduler = get_optimizer(self.C, self.opts.classifier.opt)

    self.set_losses()

    if self.verbose > 0:
        for split, split_loaders in self.loaders.items():
            for dom, loader in split_loaders.items():
                print(
                    "Loader {} {} : {}".format(split, dom, len(loader.dataset))
                )

    self.is_setup = True
def setup(self):
    """Prepare the trainer before it can be used to train the models:
    * initialize G (and D / C as required by the opts)
    * compute latent and input space dims when G has an encoder
    * create optimizers + schedulers for the models that exist
    * optionally resume from a checkpoint
    * build the losses and the per-mode/domain display images
    """
    self.logger.global_step = 0
    start_time = time()
    self.logger.time.start_time = start_time

    self.loaders = get_all_loaders(self.opts)

    self.G: OmniGenerator = get_gen(self.opts, verbose=self.verbose).to(self.device)
    if self.G.encoder is not None:
        # Latent shape only makes sense when an encoder exists.
        self.latent_shape = self.compute_latent_shape()
    self.input_shape = self.compute_input_shape()
    # Painter z spatial dims: input downsampled spade_n_up times.
    self.painter_z_h = self.input_shape[-2] // (2 ** self.opts.gen.p.spade_n_up)
    self.painter_z_w = self.input_shape[-1] // (2 ** self.opts.gen.p.spade_n_up)

    self.D: OmniDiscriminator = get_dis(
        self.opts, verbose=self.verbose
    ).to(self.device)

    # The latent classifier is only created when there is an encoder whose
    # latent space should be domain-adapted.
    self.C: OmniClassifier = None
    if self.G.encoder is not None and self.opts.train.latent_domain_adaptation:
        self.C = get_classifier(
            self.opts, self.latent_shape, verbose=self.verbose
        ).to(self.device)

    self.print_num_parameters()

    self.g_opt, self.g_scheduler = get_optimizer(self.G, self.opts.gen.opt)

    # D may be empty (no discriminator configured): skip its optimizer then.
    if get_num_params(self.D) > 0:
        self.d_opt, self.d_scheduler = get_optimizer(self.D, self.opts.dis.opt)
    else:
        self.d_opt, self.d_scheduler = None, None

    if self.C is not None:
        self.c_opt, self.c_scheduler = get_optimizer(self.C, self.opts.classifier.opt)
    else:
        self.c_opt, self.c_scheduler = None, None

    if self.opts.train.resume:
        self.resume()

    self.losses = get_losses(self.opts, self.verbose, device=self.device)

    if self.verbose > 0:
        for mode, mode_dict in self.loaders.items():
            for domain, domain_loader in mode_dict.items():
                print(
                    "Loader {} {} : {}".format(
                        mode, domain, len(domain_loader.dataset)
                    )
                )

    # Create display images:
    print("Creating display images...", end="", flush=True)
    # display_size is either an int (take the first n samples) or an
    # explicit list of dataset indices. Fixed: use isinstance() rather than
    # type(...) == int, which is non-idiomatic and rejects int subclasses.
    if isinstance(self.opts.comet.display_size, int):
        display_indices = range(self.opts.comet.display_size)
    else:
        display_indices = self.opts.comet.display_size

    self.display_images = {}
    for mode, mode_dict in self.loaders.items():
        self.display_images[mode] = {}
        for domain, domain_loader in mode_dict.items():
            # Guard i < len(dataset) so a display_size larger than the
            # dataset does not raise an IndexError.
            self.display_images[mode][domain] = [
                Dict(self.loaders[mode][domain].dataset[i])
                for i in display_indices
                if i < len(self.loaders[mode][domain].dataset)
            ]

    self.is_setup = True
# ------------------------
# ----- Test Setup -----
# ------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
target_domains = ["r", "s"]

# Class-index targets for cross-entropy, one-hot targets for L1.
labels = domains_to_class_tensor(target_domains, one_hot=False).to(device)
one_hot_labels = domains_to_class_tensor(target_domains, one_hot=True).to(device)

cross_entropy = CrossEntropy()
loss_l1 = L1Loss()

# ------------------------------
# ----- Test C.forward() ----
# ------------------------------
latent_space = (128, 32, 32)
# Dummy batch: one all-ones latent per target domain.
z = torch.ones(len(target_domains), *latent_space).to(device)
C = get_classifier(opts, latent_space, 0).to(device)
y = C(z)
tprint(
    "output of classifier's shape for latent space {} :".format(
        list(z.shape[1:])
    ),
    y.shape,
)

# --------------------------------
# ----- Test cross_entropy -----
# --------------------------------
tprint("CE loss:", cross_entropy(y, labels))

# --------------------------
# ----- Test l1_loss -----
# --------------------------
tprint("l1 loss:", loss_l1(y, one_hot_labels))