def __init__(self,
             transformer_config,
             first_stage_config,
             cond_stage_config,
             permuter_config=None,
             ckpt_path=None,
             ignore_keys=None,
             first_stage_key="image",
             cond_stage_key="depth",
             downsample_cond_size=-1,
             pkeep=1.0,
             sos_token=0,
             unconditional=False,
             ):
    """Build the conditional transformer together with its first-stage and cond-stage models.

    Args:
        transformer_config: config passed to ``instantiate_from_config`` to build
            ``self.transformer``.
        first_stage_config: config handed to ``init_first_stage_from_ckpt``.
        cond_stage_config: config handed to ``init_cond_stage_from_ckpt``; that method
            also accepts the sentinel strings ``"__is_first_stage__"`` and
            ``"__is_unconditional__"``.
        permuter_config: config for ``self.permuter``; when ``None`` the
            ``taming.modules.transformer.permuter.Identity`` target is used.
        ckpt_path: optional checkpoint passed to ``init_from_ckpt`` once every
            submodule has been created.
        ignore_keys: forwarded as ``ignore_keys`` to ``init_from_ckpt``; ``None``
            (the default) means an empty list.
        first_stage_key: stored as ``self.first_stage_key``.
        cond_stage_key: stored as ``self.cond_stage_key`` (may be overwritten by
            ``init_cond_stage_from_ckpt`` in unconditional mode).
        downsample_cond_size: stored as ``self.downsample_cond_size``.
        pkeep: stored as ``self.pkeep``.
        sos_token: stored as ``self.sos_token``; used by the unconditional cond stage.
        unconditional: initial value of ``self.be_unconditional``.
    """
    super().__init__()
    # init_cond_stage_from_ckpt reads be_unconditional, sos_token and
    # first_stage_key and may overwrite cond_stage_key, so assign them first.
    self.be_unconditional = unconditional
    self.sos_token = sos_token
    self.first_stage_key = first_stage_key
    self.cond_stage_key = cond_stage_key
    # The first stage must exist before the cond stage: "__is_first_stage__"
    # reuses self.first_stage_model.
    self.init_first_stage_from_ckpt(first_stage_config)
    self.init_cond_stage_from_ckpt(cond_stage_config)
    if permuter_config is None:
        permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
    self.permuter = instantiate_from_config(config=permuter_config)
    self.transformer = instantiate_from_config(config=transformer_config)

    if ckpt_path is not None:
        # Use a fresh list instead of a shared mutable default argument.
        self.init_from_ckpt(ckpt_path, ignore_keys=[] if ignore_keys is None else ignore_keys)
    self.downsample_cond_size = downsample_cond_size
    self.pkeep = pkeep
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    """Instantiate a model from ``config`` for interactive use, optionally loading ``sd``.

    Before instantiation the config is edited in place:
      * ``config.params.ckpt_path`` is set to ``None`` when present;
      * a ``downsample_cond_size`` entry is reset to ``-1`` and
        ``downsample_cond_factor = 0.5`` is added instead;
      * ``ckpt_path`` inside the first-stage and cond-stage ``params`` is set to
        ``None`` when present (best effort, each stage independently).
    Each edit is announced with ``st.warning``.

    Args:
        config: config with a ``params`` node, passed to ``instantiate_from_config``.
        sd: optional state dict, loaded with ``strict=False``; the missing and
            unexpected keys are reported via ``st.info``. Skipped when ``None``.
        gpu: call ``model.cuda()`` when true.
        eval_mode: call ``model.eval()`` when true.

    Returns:
        ``{"model": model}``.
    """
    if "ckpt_path" in config.params:
        st.warning("Deleting the restore-ckpt path from the config...")
        config.params.ckpt_path = None
    if "downsample_cond_size" in config.params:
        st.warning("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
        config.params.downsample_cond_size = -1
        config.params["downsample_cond_factor"] = 0.5

    # Each stage is handled on its own so that a missing/sentinel first-stage
    # config no longer prevents clearing the cond-stage checkpoint path.
    for stage_name, label in (("first_stage_config", "first-stage"),
                              ("cond_stage_config", "cond-stage")):
        try:
            stage_params = getattr(config.params, stage_name).params
            if "ckpt_path" in stage_params:
                stage_params.ckpt_path = None
                st.warning(f"Deleting the {label} restore-ckpt path from the config...")
        except Exception:
            # Best effort: the stage config may be absent or a plain sentinel
            # string (e.g. "__is_first_stage__") without a params node.
            # Unlike a bare except, this does not swallow KeyboardInterrupt/SystemExit.
            continue

    model = instantiate_from_config(config)
    if sd is not None:
        missing, unexpected = model.load_state_dict(sd, strict=False)
        st.info(f"Missing Keys in State Dict: {missing}")
        st.info(f"Unexpected Keys in State Dict: {unexpected}")
    if gpu:
        model.cuda()
    if eval_mode:
        model.eval()
    return {"model": model}
def __init__(self,
             ddconfig,
             lossconfig,
             n_embed,
             embed_dim,
             ckpt_path=None,
             ignore_keys=None,
             image_key="image",
             colorize_nlabels=None,
             monitor=None
             ):
    """Build a VQ autoencoder whose decoder is a ``SimpleDecoder`` with 3 output channels.

    Args:
        ddconfig: keyword arguments for ``Encoder`` and ``SimpleDecoder``; must contain
            ``"z_channels"``.
        lossconfig: config passed to ``instantiate_from_config`` for ``self.loss``.
        n_embed: first argument to ``VectorQuantizer``.
        embed_dim: second argument to ``VectorQuantizer``; also the channel count
            between ``quant_conv`` and ``post_quant_conv``.
        ckpt_path: optional checkpoint passed to ``init_from_ckpt``.
        ignore_keys: forwarded as ``ignore_keys`` to ``init_from_ckpt``; ``None``
            (the default) means an empty list.
        image_key: stored as ``self.image_key``.
        colorize_nlabels: if given (must be an ``int``), registers a random
            ``(3, colorize_nlabels, 1, 1)`` buffer named ``colorize``.
        monitor: stored as ``self.monitor`` only when not ``None``.
    """
    super().__init__()
    self.image_key = image_key
    self.encoder = Encoder(**ddconfig)
    self.decoder = SimpleDecoder(**ddconfig, out_channels=3)
    self.loss = instantiate_from_config(lossconfig)
    self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)
    # 1x1 convolutions mapping between ddconfig["z_channels"] and embed_dim.
    self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
    self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
    if ckpt_path is not None:
        # Use a fresh list instead of a shared mutable default argument.
        self.init_from_ckpt(ckpt_path, ignore_keys=[] if ignore_keys is None else ignore_keys)
    if colorize_nlabels is not None:
        assert type(colorize_nlabels) == int
        # Registered after init_from_ckpt, so it is not populated from the checkpoint.
        self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
    if monitor is not None:
        self.monitor = monitor
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    """Instantiate a model from ``config`` and prepare it for use.

    Args:
        config: configuration understood by ``instantiate_from_config``.
        sd: optional state dict passed to ``load_state_dict``; skipped when ``None``.
        gpu: move the model with ``.cuda()`` when true.
        eval_mode: switch the model with ``.eval()`` when true.

    Returns:
        A dict with the single key ``"model"``.
    """
    net = instantiate_from_config(config)
    if sd is not None:
        net.load_state_dict(sd)
    if gpu:
        net.cuda()
    if eval_mode:
        net.eval()
    return dict(model=net)
def init_cond_stage_from_ckpt(self, config):
    """Populate ``self.cond_stage_model`` according to ``config``.

    * ``"__is_first_stage__"``: reuse ``self.first_stage_model``.
    * ``"__is_unconditional__"`` (or ``self.be_unconditional`` already true): switch to
      unconditional mode, align ``cond_stage_key`` with ``first_stage_key`` and use a
      ``SOSProvider(self.sos_token)``.
    * anything else: instantiate it, put it in eval mode and override its ``train``
      attribute with ``disabled_train``.
    """
    if config == "__is_first_stage__":
        print("Using first stage also as cond stage.")
        self.cond_stage_model = self.first_stage_model
        return

    if config == "__is_unconditional__" or self.be_unconditional:
        print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
              f"Prepending {self.sos_token} as a sos token.")
        self.be_unconditional = True
        self.cond_stage_key = self.first_stage_key
        self.cond_stage_model = SOSProvider(self.sos_token)
        return

    cond_model = instantiate_from_config(config).eval()
    # disabled_train is defined elsewhere; presumably it keeps the model from
    # leaving eval mode when .train() is called — confirm against its definition.
    cond_model.train = disabled_train
    self.cond_stage_model = cond_model
def __init__(
    self,
    ddconfig,
    lossconfig,
    n_embed,
    embed_dim,
    temperature_scheduler_config,
    ckpt_path=None,
    ignore_keys=None,
    image_key="image",
    colorize_nlabels=None,
    monitor=None,
    kl_weight=1e-8,
    remap=None,
):
    """Build the parent VQ model, then swap its quantizer for ``GumbelQuantize``.

    Args:
        ddconfig: encoder/decoder config; must contain ``"z_channels"``.
        lossconfig: loss config forwarded to the parent; afterwards
            ``self.loss.n_classes`` is set to ``n_embed``.
        n_embed: number of codebook entries; also stored as ``self.vocab_size``.
        embed_dim: embedding dimension forwarded to the parent and ``GumbelQuantize``.
        temperature_scheduler_config: config for ``self.temperature_scheduler``
            (used for annealing the quantizer temperature).
        ckpt_path: optional checkpoint loaded with ``init_from_ckpt`` at the very end.
        ignore_keys: forwarded as ``ignore_keys``; ``None`` (the default) means an
            empty list.
        image_key, colorize_nlabels, monitor: forwarded unchanged to the parent.
        kl_weight: forwarded to ``GumbelQuantize``.
        remap: forwarded to ``GumbelQuantize``.
    """
    if ignore_keys is None:
        # Use a fresh list instead of a shared mutable default argument.
        ignore_keys = []
    z_channels = ddconfig["z_channels"]
    # ckpt_path is deliberately withheld from the parent: the checkpoint is loaded
    # only after self.quantize has been replaced by GumbelQuantize below.
    super().__init__(
        ddconfig,
        lossconfig,
        n_embed,
        embed_dim,
        ckpt_path=None,
        ignore_keys=ignore_keys,
        image_key=image_key,
        colorize_nlabels=colorize_nlabels,
        monitor=monitor,
    )

    self.loss.n_classes = n_embed
    self.vocab_size = n_embed

    self.quantize = GumbelQuantize(z_channels, embed_dim,
                                   n_embed=n_embed,
                                   kl_weight=kl_weight, temp_init=1.0,
                                   remap=remap)

    self.temperature_scheduler = instantiate_from_config(
        temperature_scheduler_config)  # annealing of temp

    if ckpt_path is not None:
        self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def __init__(
    self,
    ddconfig,
    lossconfig,
    n_embed,
    embed_dim,
    ckpt_path=None,
    ignore_keys=None,
    image_key="image",
    colorize_nlabels=None,
    monitor=None,
    remap=None,
    sane_index_shape=False,  # tell vector quantizer to return indices as bhw
):
    """Build the VQ autoencoder: encoder, decoder, vector quantizer and loss.

    Args:
        ddconfig: keyword arguments for ``Encoder`` and ``Decoder``; must contain
            ``"z_channels"``.
        lossconfig: config passed to ``instantiate_from_config`` for ``self.loss``.
        n_embed: first argument to ``VectorQuantizer``.
        embed_dim: second argument to ``VectorQuantizer``; also the channel count
            between ``quant_conv`` and ``post_quant_conv``.
        ckpt_path: optional checkpoint passed to ``init_from_ckpt``.
        ignore_keys: forwarded as ``ignore_keys`` to ``init_from_ckpt``; ``None``
            (the default) means an empty list.
        image_key: stored as ``self.image_key``.
        colorize_nlabels: if given (must be an ``int``), registers a random
            ``(3, colorize_nlabels, 1, 1)`` buffer named ``colorize``.
        monitor: stored as ``self.monitor`` only when not ``None``.
        remap: forwarded to ``VectorQuantizer``.
        sane_index_shape: forwarded to ``VectorQuantizer``.
    """
    super().__init__()
    self.image_key = image_key
    self.encoder = Encoder(**ddconfig)
    self.decoder = Decoder(**ddconfig)
    self.loss = instantiate_from_config(lossconfig)
    self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                    remap=remap, sane_index_shape=sane_index_shape)
    # 1x1 convolutions mapping between ddconfig["z_channels"] and embed_dim.
    self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
    self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
    if ckpt_path is not None:
        # Use a fresh list instead of a shared mutable default argument.
        self.init_from_ckpt(ckpt_path, ignore_keys=[] if ignore_keys is None else ignore_keys)
    if colorize_nlabels is not None:
        assert type(colorize_nlabels) == int
        # Registered after init_from_ckpt, so it is not populated from the checkpoint.
        self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
    if monitor is not None:
        self.monitor = monitor
def get_data(config):
    """Instantiate the data object described by ``config.data`` and initialise it.

    Calls ``prepare_data()`` and then ``setup()`` on the instantiated object.

    Returns:
        The instantiated data object.
    """
    datamodule = instantiate_from_config(config.data)
    datamodule.prepare_data()
    datamodule.setup()
    return datamodule
def init_first_stage_from_ckpt(self, config):
    """Instantiate the first-stage model from ``config`` and store it as ``self.first_stage_model``.

    The model is switched to eval mode and its ``train`` attribute is overridden
    with ``disabled_train``.
    """
    first_stage = instantiate_from_config(config).eval()
    # disabled_train is defined elsewhere; presumably it keeps the model from
    # leaving eval mode when .train() is called — confirm against its definition.
    first_stage.train = disabled_train
    self.first_stage_model = first_stage