def init_net(net, init_type="normal", init_gain=0.02, gpu_ids=None):
    """Optionally move a network to GPU, then initialize its weights.

    Args:
        net: Network to prepare.
        init_type: Forwarded to ``init_weights`` as ``init_type``.
        init_gain: Forwarded to ``init_weights`` as ``init_gain``.
        gpu_ids: Sequence of device ids. ``None`` or an empty sequence
            skips the device move and the ``DataParallel`` wrapping.

    Returns:
        The network passed to ``init_weights``: ``net`` itself, or ``net``
        wrapped in ``torch.nn.DataParallel`` when ``gpu_ids`` is non-empty.

    Raises:
        AssertionError: If ``gpu_ids`` is non-empty but CUDA is unavailable
            (the check is skipped when Python runs with ``-O``).
    """
    # A ``None`` sentinel replaces the former shared mutable ``[]`` default.
    if gpu_ids is None:
        gpu_ids = []

    if len(gpu_ids) > 0:
        assert torch.cuda.is_available(), (
            "gpu_ids were provided but CUDA is not available"
        )
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)

    # Initialization happens after wrapping, so it is applied to the
    # DataParallel wrapper when one was created.
    init_weights(net, init_type=init_type, init_gain=init_gain)
    return net
def get_classifier(opts, latent_space, verbose):
    """Create an ``OmniClassifier`` and initialize its weights.

    Args:
        opts: Options object; ``opts.classifier`` supplies ``loss``,
            ``init_type`` and ``init_gain``.
        latent_space: First positional argument of ``OmniClassifier``.
        verbose: Forwarded to ``init_weights``.

    Returns:
        The initialized classifier.
    """
    classifier_opts = opts.classifier
    classifier = OmniClassifier(latent_space, classifier_opts.loss)
    init_weights(
        classifier,
        init_type=classifier_opts.init_type,
        init_gain=classifier_opts.init_gain,
        verbose=verbose,
    )
    return classifier
def get_dis(opts, verbose):
    """Create an ``OmniDiscriminator`` and initialize every sub-network.

    The discriminator is walked as a two-level mapping: task -> collection
    of per-domain networks. Each network is initialized with the
    ``init_type`` / ``init_gain`` found under ``opts.dis[task]``.

    Args:
        opts: Options object passed to ``OmniDiscriminator`` and indexed
            by task name through ``opts.dis``.
        verbose: Forwarded to ``init_weights``.

    Returns:
        The discriminator with all of its sub-networks initialized.
    """
    discriminator = OmniDiscriminator(opts)
    for task, domain_models in discriminator.items():
        for network in domain_models.values():
            # Looked up per network: a task without any domain network
            # never touches ``opts.dis``.
            task_opts = opts.dis[task]
            init_weights(
                network,
                init_type=task_opts.init_type,
                init_gain=task_opts.init_gain,
                verbose=verbose,
            )
    return discriminator
def get_gen(opts, latent_shape=None, verbose=0):
    """Create an ``OmniGenerator`` and initialize its sub-networks.

    Initializes, in order: every decoder (each entry of a decoder that is
    an ``nn.ModuleDict`` is initialized separately), the encoder (skipped
    when it is ``None`` or when its architecture option is ``"deeplabv2"``),
    and finally the painter.

    Args:
        opts: Options object; ``opts.gen[<name>]`` supplies ``init_type``
            and ``init_gain`` for each sub-network.
        latent_shape: Forwarded to ``OmniGenerator``.
        verbose: Forwarded to ``OmniGenerator`` and ``init_weights``.

    Returns:
        The generator with its weights initialized.
    """
    generator = OmniGenerator(opts, latent_shape, verbose)

    for name in generator.decoders:
        decoder = generator.decoders[name]
        # A ModuleDict holds one network per domain; anything else is a
        # single network.
        if isinstance(decoder, nn.ModuleDict):
            targets = [decoder[domain] for domain in decoder.keys()]
        else:
            targets = [decoder]

        for module in targets:
            init_weights(
                module,
                init_type=opts.gen[name].init_type,
                init_gain=opts.gen[name].init_gain,
                verbose=verbose,
            )

    if generator.encoder is not None:
        encoder_opts = opts.gen.encoder
        if encoder_opts.architecture != "deeplabv2":
            init_weights(
                generator.encoder,
                init_type=encoder_opts.init_type,
                init_gain=encoder_opts.init_gain,
                verbose=verbose,
            )

    # Init painter weights
    painter_opts = opts.gen.p
    init_weights(
        generator.painter,
        init_type=painter_opts.init_type,
        init_gain=painter_opts.init_gain,
        verbose=verbose,
    )
    return generator
def set_translation_decoder(self, latent_shape, device):
    """Install the translation decoder under ``self.decoders["t"]``.

    Depending on ``self.opts.gen.t``:

    * ``use_bit_conditioning``: requires ``use_spade``; stores a
      ``SpadeTranslationDict`` moved to ``device``. ``init_weights`` is
      not called on it here.
    * ``use_spade`` only: stores an ``nn.ModuleDict`` with one
      ``SpadeTranslationDecoder`` per key ``"f"`` and ``"n"``, each moved
      to ``device`` and then passed through ``init_weights``.
    * neither: ``self.decoders`` is left untouched.

    Args:
        latent_shape: Forwarded to the decoder constructors.
        device: Target device for the created decoders.

    Raises:
        ValueError: If ``use_bit_conditioning`` is set without
            ``use_spade``.
    """
    translation_opts = self.opts.gen.t

    if translation_opts.use_bit_conditioning:
        if not translation_opts.use_spade:
            raise ValueError(
                "cannot have use_bit_conditioning but not use_spade")
        self.decoders["t"] = SpadeTranslationDict(latent_shape, self.opts)
        self.decoders["t"] = self.decoders["t"].to(device)
        return

    if not translation_opts.use_spade:
        # not using spade in anyway: do nothing
        return

    domains = ("f", "n")
    # Both decoders are built before either one is initialized.
    self.decoders["t"] = nn.ModuleDict({
        domain: SpadeTranslationDecoder(latent_shape, self.opts).to(device)
        for domain in domains
    })
    for domain in domains:
        init_weights(
            self.decoders["t"][domain],
            init_type=translation_opts.init_type,
            init_gain=translation_opts.init_gain,
            verbose=self.verbose,
        )
def get_gen(opts, latent_shape=None, verbose=0):
    """Create an ``OmniGenerator`` and initialize its decoders.

    The ``"t"`` decoder is skipped when ``opts.gen.t.use_spade`` or
    ``opts.gen.t.use_bit_conditioning`` is set. Every other decoder is
    passed through ``init_weights``; a decoder that is an
    ``nn.ModuleDict`` has each of its entries initialized separately.

    Args:
        opts: Options object; ``opts.gen[<name>]`` supplies ``init_type``
            and ``init_gain`` for each decoder.
        latent_shape: Forwarded to ``OmniGenerator``.
        verbose: Forwarded to ``OmniGenerator`` and ``init_weights``.

    Returns:
        The generator with its decoders initialized.
    """
    generator = OmniGenerator(opts, latent_shape, verbose)

    for name in generator.decoders:
        # Options for "t" are only inspected for the "t" decoder itself.
        if name == "t" and (
            opts.gen.t.use_spade or opts.gen.t.use_bit_conditioning
        ):
            continue

        decoder = generator.decoders[name]
        if isinstance(decoder, nn.ModuleDict):
            targets = [decoder[domain] for domain in decoder.keys()]
        else:
            targets = [decoder]

        for module in targets:
            init_weights(
                module,
                init_type=opts.gen[name].init_type,
                init_gain=opts.gen[name].init_gain,
                verbose=verbose,
            )

    return generator