def __init__(self,
             ddconfig,
             lossconfig,
             n_embed,
             embed_dim,
             ckpt_path=None,
             ignore_keys=[],
             image_key="image",
             colorize_nlabels=None,
             monitor=None,
             batch_resize_range=None,
             scheduler_config=None,
             lr_g_factor=1.0,
             ):
    super().__init__()
    self.image_key = image_key
    self.encoder = Encoder(**ddconfig)
    self.decoder = Decoder(**ddconfig)
    self.loss = instantiate_from_config(lossconfig)
    self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)
    # 1x1 convs bridge the encoder/decoder latent space and the codebook space
    self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
    self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
    if colorize_nlabels is not None:
        assert isinstance(colorize_nlabels, int)
        self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
    if monitor is not None:
        self.monitor = monitor
    self.batch_resize_range = batch_resize_range
    if self.batch_resize_range is not None:
        print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
    if ckpt_path is not None:
        self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
    self.scheduler_config = scheduler_config
    self.lr_g_factor = lr_g_factor
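# A minimal, self-contained sketch (not part of the model) of how the two 1x1
# convolutions above bridge the encoder/decoder latent space and the codebook
# space. The shape values are illustrative assumptions, not shipped config.
import torch

z_channels, embed_dim = 256, 256  # hypothetical ddconfig["z_channels"] / embed_dim
quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)

h = torch.randn(1, z_channels, 16, 16)  # stand-in for Encoder output
z = quant_conv(h)                       # project into codebook space for VectorQuantizer
h_dec = post_quant_conv(z)              # project back before the Decoder
assert h_dec.shape == h.shape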
def init_cond_stage_from_ckpt(self, config):
    if config == "__is_first_stage__":
        print("Using first stage also as cond stage.")
        self.cond_stage_model = self.first_stage_model
    else:
        model = instantiate_from_config(config)
        self.cond_stage_model = model.eval()
def configure_optimizers(self):
    lr_d = self.learning_rate
    lr_g = self.lr_g_factor * self.learning_rate
    print("lr_d", lr_d)
    print("lr_g", lr_g)
    opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                              list(self.decoder.parameters()) +
                              list(self.quantize.parameters()) +
                              list(self.quant_conv.parameters()) +
                              list(self.post_quant_conv.parameters()),
                              lr=lr_g, betas=(0.5, 0.9))
    opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                lr=lr_d, betas=(0.5, 0.9))

    if self.scheduler_config is not None:
        scheduler = instantiate_from_config(self.scheduler_config)
        print("Setting up LambdaLR scheduler...")
        # scheduler.schedule is the step -> lr-multiplier callable used by LambdaLR
        schedulers = [
            {
                'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                'interval': 'step',
                'frequency': 1
            },
            {
                'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
                'interval': 'step',
                'frequency': 1
            },
        ]
        return [opt_ae, opt_disc], schedulers
    return [opt_ae, opt_disc], []
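# A hedged sketch of the scheduler contract assumed by configure_optimizers:
# the object built from self.scheduler_config must expose a
# `schedule(step) -> float` multiplier that LambdaLR applies to the base lr.
# `LambdaWarmUpScheduler` and `warm_up_steps` are illustrative assumptions,
# not identifiers from this repo.
import torch
from torch.optim.lr_scheduler import LambdaLR

class LambdaWarmUpScheduler:
    def __init__(self, warm_up_steps=1000):
        self.warm_up_steps = warm_up_steps

    def schedule(self, step):
        # linear warmup to a multiplier of 1.0, then constant
        return min((step + 1) / self.warm_up_steps, 1.0)

opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=4.5e-6)
sched = LambdaLR(opt, lr_lambda=LambdaWarmUpScheduler().schedule)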
def init_emb_stage_from_ckpt(self, config):
    if config is None:
        self.emb_stage_model = None
    else:
        model = instantiate_from_config(config)
        self.emb_stage_model = model
        if not self.emb_stage_trainable:
            self.emb_stage_model.eval()
def configure_optimizers(self):
    # separate out all parameters into those that will and won't experience
    # regularizing weight decay
    decay = set()
    no_decay = set()
    whitelist_weight_modules = (torch.nn.Linear, )
    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
    for mn, m in self.transformer.named_modules():
        for pn, p in m.named_parameters():
            fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
            if pn.endswith('bias'):
                # all biases will not be decayed
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                # weights of whitelist modules will be weight decayed
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                # weights of blacklist modules will NOT be weight decayed
                no_decay.add(fpn)

    # special case the position embedding parameter in the root GPT module as not decayed
    no_decay.add('pos_emb')

    # validate that we considered every parameter
    param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
    inter_params = decay & no_decay
    union_params = decay | no_decay
    assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
    assert len(param_dict.keys() - union_params) == 0, \
        "parameters %s were not separated into either decay/no_decay set!" % (str(param_dict.keys() - union_params), )

    # create the pytorch optimizer object
    optim_groups = [
        {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
        {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
    ]

    # extra modules outside the transformer are never weight decayed
    extra_parameters = list()
    if self.emb_stage_trainable:
        extra_parameters += list(self.emb_stage_model.parameters())
    if hasattr(self, "merge_conv"):
        extra_parameters += list(self.merge_conv.parameters())
    else:
        assert self.merge_channels is None
    optim_groups.append({"params": extra_parameters, "weight_decay": 0.0})
    print(f"Optimizing {len(extra_parameters)} extra parameters.")

    optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
    if self.use_scheduler:
        scheduler = instantiate_from_config(self.scheduler_config)
        print("Setting up LambdaLR scheduler...")
        schedulers = [
            {
                'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
                'interval': 'step',
                'frequency': 1
            }]
        return [optimizer], schedulers
    return optimizer
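# The decay/no_decay split above, demonstrated on a toy module instead of
# self.transformer so it can be run standalone; module names are made up.
import torch

toy = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.LayerNorm(4))
decay, no_decay = set(), set()
for mn, m in toy.named_modules():
    for pn, p in m.named_parameters():
        fpn = f"{mn}.{pn}" if mn else pn
        if pn.endswith("bias"):
            no_decay.add(fpn)
        elif pn.endswith("weight") and isinstance(m, torch.nn.Linear):
            decay.add(fpn)
        elif pn.endswith("weight") and isinstance(m, torch.nn.LayerNorm):
            no_decay.add(fpn)

print(sorted(decay))     # ['0.weight']
print(sorted(no_decay))  # ['0.bias', '1.bias', '1.weight']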
def __init__(self,
             transformer_config,
             first_stage_config,
             cond_stage_config,
             ckpt_path=None,
             ignore_keys=[],
             first_stage_key="image",
             cond_stage_key="depth",
             use_scheduler=False,
             scheduler_config=None,
             monitor="val/loss",
             downsample_cond_size=-1,
             pkeep=1.0,
             plot_cond_stage=False,
             log_det_sample=False,
             manipulate_latents=False,
             emb_stage_config=None,
             emb_stage_key="camera",
             emb_stage_trainable=True,
             top_k=None,
             unconditional=False,
             sos_token=0,
             use_first_stage_get_input=False,
             ):
    super().__init__()
    self.use_first_stage_get_input = use_first_stage_get_input
    if monitor is not None:
        self.monitor = monitor
    self.be_unconditional = unconditional
    self.sos_token = sos_token
    self.first_stage_key = first_stage_key
    self.log_det_sample = log_det_sample
    self.manipulate_latents = manipulate_latents
    self.init_first_stage_from_ckpt(first_stage_config)
    self.init_cond_stage_from_ckpt(cond_stage_config)
    self.transformer = instantiate_from_config(config=transformer_config)

    if ckpt_path is not None:
        self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
    # init_cond_stage_from_ckpt may already have set cond_stage_key (unconditional case)
    if not hasattr(self, "cond_stage_key"):
        self.cond_stage_key = cond_stage_key
    self.downsample_cond_size = downsample_cond_size
    self.pkeep = pkeep
    self.use_scheduler = use_scheduler
    if use_scheduler:
        assert scheduler_config is not None
        self.scheduler_config = scheduler_config
    self.plot_cond_stage = plot_cond_stage
    self.emb_stage_key = emb_stage_key
    self.emb_stage_trainable = emb_stage_trainable and emb_stage_config is not None
    if self.emb_stage_trainable:
        print("### TRAINING EMB STAGE!!!")
    self.init_emb_stage_from_ckpt(emb_stage_config)
    self.top_k = top_k if top_k is not None else 100
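# A hedged sketch of what `pkeep` typically controls in net2net-style
# transformers (based on similar public code, not necessarily this class's
# exact forward pass): during training, each ground-truth token is kept with
# probability pkeep and otherwise replaced by a random codebook index.
import torch

pkeep, vocab_size = 0.9, 1024                            # illustrative values
z_indices = torch.randint(vocab_size, (2, 256))          # ground-truth code indices
mask = torch.bernoulli(pkeep * torch.ones_like(z_indices, dtype=torch.float))
mask = mask.to(dtype=torch.int64)
r_indices = torch.randint_like(z_indices, vocab_size)    # random replacements
a_indices = mask * z_indices + (1 - mask) * r_indices    # corrupted input tokens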
def __init__(self, **kwargs):
    assert "n_unmasked" in kwargs and kwargs["n_unmasked"] != 0
    # forward the relevant GPT hyperparameters to the warper sub-module
    warper_config = kwargs.pop("warper_config")
    if warper_config.params is None:
        warper_config.params = dict()
    warper_config.params["n_unmasked"] = kwargs["n_unmasked"]
    warper_config.params["block_size"] = kwargs["block_size"]
    warper_config.params["n_embd"] = kwargs["n_embd"]
    super().__init__(**kwargs)
    self.warper = instantiate_from_config(warper_config)
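# For context: `warper_config` follows the usual `target`/`params` convention.
# A plausible minimal definition of `instantiate_from_config`, modeled on
# taming-transformers-style utilities (a sketch, not necessarily this repo's
# exact helper):
import importlib

def get_obj_from_str(string):
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config):
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))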
def init_cond_stage_from_ckpt(self, config):
    if config == "__is_first_stage__":
        print("Using first stage also as cond stage.")
        self.cond_stage_model = self.first_stage_model
    elif config == "__is_unconditional__" or self.be_unconditional:
        print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
              f"Prepending {self.sos_token} as a sos token.")
        self.be_unconditional = True
        self.cond_stage_key = self.first_stage_key
        self.cond_stage_model = SOSProvider(self.sos_token)
    else:
        model = instantiate_from_config(config)
        self.cond_stage_model = model.eval()
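# A hedged sketch of the interface `SOSProvider` needs for the unconditional
# branch above: ignore the conditioning input and emit one constant
# start-of-sequence token per batch element. Modeled on taming-transformers'
# SOSProvider; the exact return signature in this repo may differ.
import torch

class SOSProviderSketch(torch.nn.Module):
    def __init__(self, sos_token):
        super().__init__()
        self.sos_token = sos_token

    def encode(self, x):
        # shape (B, 1): a single sos index per example
        return (torch.ones(x.shape[0], 1, device=x.device) * self.sos_token).long()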
def __init__(self,
             transformer_config,
             first_stage_config,
             cond_stage_config,
             ckpt_path=None,
             ignore_keys=[],
             first_stage_key="dst_img",
             cond_stage_key="src_img",
             use_scheduler=False,
             scheduler_config=None,
             monitor="val/loss",
             pkeep=1.0,
             plot_cond_stage=False,
             log_det_sample=False,
             top_k=None):
    super().__init__()
    if monitor is not None:
        self.monitor = monitor
    self.log_det_sample = log_det_sample
    self.init_first_stage_from_ckpt(first_stage_config)
    self.init_cond_stage_from_ckpt(cond_stage_config)
    self.transformer = instantiate_from_config(config=transformer_config)

    if ckpt_path is not None:
        self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
    self.first_stage_key = first_stage_key
    self.cond_stage_key = cond_stage_key
    self.pkeep = pkeep
    self.use_scheduler = use_scheduler
    if use_scheduler:
        assert scheduler_config is not None
        self.scheduler_config = scheduler_config
    self.plot_cond_stage = plot_cond_stage
    self.top_k = top_k if top_k is not None else 100
    # map warp-operator keyword arguments to the corresponding batch keys
    self.warpkwargs_keys = {
        "x": "src_img",
        "points": "src_points",
        "R": "R_rel",
        "t": "t_rel",
        "K_dst": "K",
        "K_src_inv": "K_inv",
    }
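# A hedged sketch of how a mapping like `warpkwargs_keys` is typically
# consumed: look up the named tensors in a batch dict and call the warping
# operator with the keyword names it expects. `warp` is hypothetical here,
# not an identifier from this excerpt.
def gather_warpkwargs(batch, warpkwargs_keys):
    return {kwarg: batch[batch_key] for kwarg, batch_key in warpkwargs_keys.items()}

# usage: warped = warp(**gather_warpkwargs(batch, self.warpkwargs_keys))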
def init_depth_stage_from_ckpt(self, config):
    model = instantiate_from_config(config)
    self.depth_stage_model = model.eval()
    # keep the depth stage in eval mode even if .train() is called on the parent
    self.depth_stage_model.train = disabled_train
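# `disabled_train` is assumed to be the usual no-op override from
# taming-style codebases, pinning a module in eval mode; a plausible definition:
def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self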
def __init__(self,
             transformer_config,
             first_stage_config,
             cond_stage_config,
             depth_stage_config,
             merge_channels=None,
             ckpt_path=None,
             ignore_keys=[],
             first_stage_key="image",
             cond_stage_key="depth",
             use_scheduler=False,
             scheduler_config=None,
             monitor="val/loss",
             plot_cond_stage=False,
             log_det_sample=False,
             manipulate_latents=False,
             emb_stage_config=None,
             emb_stage_key="camera",
             emb_stage_trainable=True,
             top_k=None
             ):
    super().__init__()
    if monitor is not None:
        self.monitor = monitor
    self.log_det_sample = log_det_sample
    self.manipulate_latents = manipulate_latents
    self.init_first_stage_from_ckpt(first_stage_config)
    self.init_cond_stage_from_ckpt(cond_stage_config)
    self.init_depth_stage_from_ckpt(depth_stage_config)
    self.transformer = instantiate_from_config(config=transformer_config)

    self.merge_channels = merge_channels
    if self.merge_channels is not None:
        # 1x1 conv fuses the concatenated conditioning channels into the
        # transformer embedding width
        self.merge_conv = torch.nn.Conv2d(self.merge_channels,
                                          self.transformer.config.n_embd,
                                          kernel_size=1,
                                          padding=0,
                                          bias=False)

    self.first_stage_key = first_stage_key
    self.cond_stage_key = cond_stage_key
    self.use_scheduler = use_scheduler
    if use_scheduler:
        assert scheduler_config is not None
        self.scheduler_config = scheduler_config
    self.plot_cond_stage = plot_cond_stage
    self.emb_stage_key = emb_stage_key
    self.emb_stage_trainable = emb_stage_trainable and emb_stage_config is not None
    if self.emb_stage_trainable:
        print("### TRAINING EMB STAGE!!!")
    self.init_emb_stage_from_ckpt(emb_stage_config)
    self.top_k = top_k if top_k is not None else 100

    if ckpt_path is not None:
        self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    self._midas = Midas()
    self._midas.eval()
    self._midas.train = disabled_train
    # map warp-operator keyword arguments to the corresponding batch keys
    self.warpkwargs_keys = {
        "x": "src_img",
        "points": "src_points",
        "R": "R_rel",
        "t": "t_rel",
        "K_dst": "K",
        "K_src_inv": "K_inv",
    }
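# A minimal sketch (illustrative shapes only) of what the 1x1 `merge_conv`
# above does: fuse concatenated conditioning feature maps into the
# transformer's embedding width without mixing spatial positions.
import torch

merge_channels, n_embd = 1024, 512  # hypothetical values
merge_conv = torch.nn.Conv2d(merge_channels, n_embd, kernel_size=1, padding=0, bias=False)

cond = torch.randn(1, merge_channels, 16, 16)  # e.g. stacked image/depth features
emb = merge_conv(cond)
assert emb.shape == (1, n_embd, 16, 16)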
def init_first_stage_from_ckpt(self, config):
    model = instantiate_from_config(config)
    self.first_stage_model = model.eval()
def __init__(self,
             transformer_config,
             first_stage_config,
             cond_stage_config,
             depth_stage_config,
             merge_channels=None,
             use_depth=True,
             ckpt_path=None,
             ignore_keys=[],
             first_stage_key="image",
             cond_stage_key="depth",
             use_scheduler=False,
             scheduler_config=None,
             monitor="val/loss",
             pkeep=1.0,
             plot_cond_stage=False,
             log_det_sample=False,
             manipulate_latents=False,
             emb_stage_config=None,
             emb_stage_key="camera",
             emb_stage_trainable=True,
             top_p=None,
             top_k=None
             ):
    super().__init__()
    if monitor is not None:
        self.monitor = monitor
    self.log_det_sample = log_det_sample
    self.manipulate_latents = manipulate_latents
    self.init_first_stage_from_ckpt(first_stage_config)
    self.init_cond_stage_from_ckpt(cond_stage_config)
    self.init_depth_stage_from_ckpt(depth_stage_config)
    self.transformer = instantiate_from_config(config=transformer_config)

    self.merge_channels = merge_channels
    if self.merge_channels is not None:
        self.merge_conv = torch.nn.Conv2d(self.merge_channels,
                                          self.transformer.config.n_embd,
                                          kernel_size=1,
                                          padding=0,
                                          bias=False)
    self.use_depth = use_depth
    if not self.use_depth:
        assert self.merge_channels is None

    self.first_stage_key = first_stage_key
    self.cond_stage_key = cond_stage_key
    # probability of keeping ground-truth tokens during training
    self.pkeep = pkeep
    self.use_scheduler = use_scheduler
    if use_scheduler:
        assert scheduler_config is not None
        self.scheduler_config = scheduler_config
    self.plot_cond_stage = plot_cond_stage
    self.emb_stage_key = emb_stage_key
    self.emb_stage_trainable = emb_stage_trainable and emb_stage_config is not None
    self.init_emb_stage_from_ckpt(emb_stage_config)
    self.top_p = top_p if top_p is not None else 0.95
    # default top_k to the first stage codebook size when available
    try:
        tk = self.first_stage_model.quantize.n_e
    except AttributeError:
        tk = 100
    self.top_k = top_k if top_k is not None else tk

    if ckpt_path is not None:
        self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    self._midas = Midas()
    self._midas.eval()
    self._midas.train = disabled_train
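# A hedged sketch of the standard top-k / top-p (nucleus) logit filtering that
# `top_k` and `top_p` above typically parameterize at sampling time; this is
# the common recipe, not necessarily this class's exact sampling code.
import torch

def top_k_top_p_filter(logits, top_k=None, top_p=None):
    # logits: (..., vocab_size); masked entries are set to -inf
    if top_k is not None:
        kth = torch.topk(logits, top_k).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    if top_p is not None:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # shift right: always keep top token
        remove[..., 0] = False
        mask = remove.scatter(-1, sorted_indices, remove)
        logits = logits.masked_fill(mask, float("-inf"))
    return logits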