class WarpingTransformer(Net2NetTransformer):
    def __init__(self, *args, **kwargs):
        kwargs["cond_stage_config"] = "__is_first_stage__"
        super().__init__(*args, **kwargs)
        self._midas = Midas()
        self._midas.eval()
        self._midas.train = disabled_train

    def get_xc(self, batch, N=None):
        if batch["dst_img"].device != self.device:
            for k in batch:
                if hasattr(batch[k], "to"):
                    batch[k] = batch[k].to(device=self.device)

        x = self.get_input("dst_img", batch)
        x_src = self.get_input("src_img", batch)
        with torch.no_grad():
            c, _ = self._midas.warp(x=x_src,
                                    points=batch["src_points"],
                                    R=batch["R_rel"],
                                    t=batch["t_rel"],
                                    K_src_inv=batch["K_inv"],
                                    K_dst=batch["K"])

        if N is not None:
            x = x[:N]
            c = c[:N]
        return x, c
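# Illustrative usage only (not part of the repo): how get_xc plugs into a
# standard Net2NetTransformer-style training step. The batch keys ("dst_img",
# "src_img", "src_points", "R_rel", "t_rel", "K", "K_inv") follow the class
# above; torch and torch.nn.functional as F are assumed to be imported by the
# surrounding module, and the forward signature model(x, c) -> (logits, target)
# follows the taming-transformers convention.
def _example_shared_step(model, batch):
    # c is the source image forward-warped into the target view via MiDaS depth
    x, c = model.get_xc(batch)
    logits, target = model(x, c)
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                           target.reshape(-1))
    return loss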
class AbstractWarper(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self._midas = Midas()
        self._midas.eval()
        self._midas.train = disabled_train
        for param in self._midas.parameters():
            param.requires_grad = False

        self.n_unmasked = kwargs["n_unmasked"]  # length of conditioning
        self.n_embd = kwargs["n_embd"]
        self.block_size = kwargs["block_size"]
        self.size = kwargs["size"]  # h, w tuple
        self.start_idx = kwargs.get("start_idx", 0)  # hint to not modify parts

        self._use_cache = False
        self.new_emb = None  # cache
        self.new_pos = None  # cache

    def set_cache(self, value):
        self._use_cache = value

    def get_embeddings(self, token_embeddings, position_embeddings, warpkwargs):
        if self._use_cache:
            assert not self.training, "Do you really want to use caching during training?"
            assert self.new_emb is not None
            assert self.new_pos is not None
            return self.new_emb, self.new_pos
        self.new_emb, self.new_pos = self._get_embeddings(
            token_embeddings, position_embeddings, warpkwargs)
        return self.new_emb, self.new_pos

    def _get_embeddings(self, token_embeddings, position_embeddings, warpkwargs):
        raise NotImplementedError()

    def forward(self, token_embeddings, position_embeddings, warpkwargs):
        new_emb, new_pos = self.get_embeddings(token_embeddings,
                                               position_embeddings,
                                               warpkwargs)
        new_emb = torch.cat(
            [new_emb, token_embeddings[:, self.n_unmasked:, :]], dim=1)
        b = new_pos.shape[0]
        new_pos = torch.cat([
            new_pos,
            position_embeddings[:, self.n_unmasked:, :][b * [0], ...]
        ], dim=1)
        return new_emb, new_pos

    def _to_sequence(self, x):
        x = rearrange(x, 'b c h w -> b (h w) c')
        return x

    def _to_imglike(self, x):
        x = rearrange(x, 'b (h w) c -> b c h w', h=self.size[0])
        return x
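# Minimal sketch of a concrete warper (illustrative only, not the repo's
# implementation). It satisfies the AbstractWarper contract -- return new
# conditioning token embeddings of shape (b, n_unmasked, c) and position
# embeddings broadcast to the batch size -- but ignores warpkwargs instead of
# actually warping with self._midas. It assumes n_unmasked is divisible by
# size[0] so the image-layout round trip works.
class IdentityWarper(AbstractWarper):
    def _get_embeddings(self, token_embeddings, position_embeddings, warpkwargs):
        cond_emb = token_embeddings[:, :self.n_unmasked, :]
        cond_pos = position_embeddings[:, :self.n_unmasked, :]
        b = cond_emb.shape[0]
        # round-trip through image layout; a real warper would modify the
        # features here using the camera parameters in warpkwargs
        cond_emb = self._to_sequence(self._to_imglike(cond_emb))
        # broadcast the (1, n, c) position embedding to the batch
        if cond_pos.shape[0] == 1:
            cond_pos = cond_pos[b * [0], ...]
        return cond_emb, cond_pos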
parser.add_argument(
    '--video',
    type=str,
    nargs='?',
    default=None,
    help='path to write video recording to. (no recording if unspecified).'
)
opt = parser.parse_args()

print(helptxt)

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    print("Warning: Running on CPU---sampling might take a while...")
    device = torch.device("cpu")

midas = Midas().eval().to(device)

# init transformer
renderer = Renderer(model=opt.model, device=device)

if opt.path is None:
    try:
        import importlib.resources as pkg_resources
    except ImportError:
        import importlib_resources as pkg_resources
    with pkg_resources.path("geofree.examples", "artist.jpg") as path:
        example = load_as_example(path)
else:
    path = opt.path
    if not os.path.isfile(path):
        Tk().withdraw()
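# Hypothetical invocation of the interactive script above (script name and
# flag names other than --video are assumptions inferred from the opt.model /
# opt.path attributes it reads; check the full argument parser for the exact
# interface):
#
#   python <this_script>.py --model <pretrained-model-name> --video out.mp4 /path/to/source_image.jpg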
class WarpGeoTransformer(pl.LightningModule):
    """GeoTransformer that also uses a warper in the transformer."""
    def __init__(self,
                 transformer_config,
                 first_stage_config,
                 cond_stage_config,
                 depth_stage_config,
                 merge_channels=None,
                 ckpt_path=None,
                 ignore_keys=[],
                 first_stage_key="image",
                 cond_stage_key="depth",
                 use_scheduler=False,
                 scheduler_config=None,
                 monitor="val/loss",
                 plot_cond_stage=False,
                 log_det_sample=False,
                 manipulate_latents=False,
                 emb_stage_config=None,
                 emb_stage_key="camera",
                 emb_stage_trainable=True,
                 top_k=None
                 ):
        super().__init__()
        if monitor is not None:
            self.monitor = monitor
        self.log_det_sample = log_det_sample
        self.manipulate_latents = manipulate_latents
        self.init_first_stage_from_ckpt(first_stage_config)
        self.init_cond_stage_from_ckpt(cond_stage_config)
        self.init_depth_stage_from_ckpt(depth_stage_config)
        self.transformer = instantiate_from_config(config=transformer_config)
        self.merge_channels = merge_channels
        if self.merge_channels is not None:
            self.merge_conv = torch.nn.Conv2d(self.merge_channels,
                                              self.transformer.config.n_embd,
                                              kernel_size=1,
                                              padding=0,
                                              bias=False)

        self.first_stage_key = first_stage_key
        self.cond_stage_key = cond_stage_key
        self.use_scheduler = use_scheduler
        if use_scheduler:
            assert scheduler_config is not None
            self.scheduler_config = scheduler_config
        self.plot_cond_stage = plot_cond_stage

        self.emb_stage_key = emb_stage_key
        self.emb_stage_trainable = emb_stage_trainable and emb_stage_config is not None
        if self.emb_stage_trainable:
            print("### TRAINING EMB STAGE!!!")
        self.init_emb_stage_from_ckpt(emb_stage_config)
        self.top_k = top_k if top_k is not None else 100

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

        self._midas = Midas()
        self._midas.eval()
        self._midas.train = disabled_train

        self.warpkwargs_keys = {
            "x": "src_img",
            "points": "src_points",
            "R": "R_rel",
            "t": "t_rel",
            "K_dst": "K",
            "K_src_inv": "K_inv",
        }

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        for k in list(sd.keys()):  # iterate over a copy since keys may be deleted
            for ik in ignore_keys:
                if k.startswith(ik):
                    self.print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing keys and {len(unexpected)} unexpected keys.")

    def init_first_stage_from_ckpt(self, config):
        model = instantiate_from_config(config)
        self.first_stage_model = model.eval()
        self.first_stage_model.train = disabled_train

    def init_cond_stage_from_ckpt(self, config):
        if config == "__is_first_stage__":
            print("Using first stage also as cond stage.")
            self.cond_stage_model = self.first_stage_model
        else:
            model = instantiate_from_config(config)
            self.cond_stage_model = model.eval()
            self.cond_stage_model.train = disabled_train

    def init_depth_stage_from_ckpt(self, config):
        model = instantiate_from_config(config)
        self.depth_stage_model = model.eval()
        self.depth_stage_model.train = disabled_train

    def init_emb_stage_from_ckpt(self, config):
        if config is None:
            self.emb_stage_model = None
        else:
            model = instantiate_from_config(config)
            self.emb_stage_model = model
            if not self.emb_stage_trainable:
                self.emb_stage_model.eval()
                self.emb_stage_model.train = disabled_train

    @torch.no_grad()
    def encode_to_z(self, x):
        quant_z, _, info = self.first_stage_model.encode(x)
        indices = info[2].view(quant_z.shape[0], -1)
        return quant_z, indices

    @torch.no_grad()
    def encode_to_c(self, c):
        quant_c, _, info = self.cond_stage_model.encode(c)
        indices = info[2].view(quant_c.shape[0], -1)
        return quant_c, indices

    @torch.no_grad()
    def encode_to_d(self, x):
        quant_z, _, info = self.depth_stage_model.encode(x)
        indices = info[2].view(quant_z.shape[0], -1)
        return quant_z, indices

    def encode_to_e(self, **kwargs):
        return self.emb_stage_model(**kwargs)

    def get_normalized_c(self, cdict, edict):
        with torch.no_grad():
            quant_c, c_indices = self.encode_to_c(**cdict)
            scaled_idepth = self._midas.scaled_depth(cdict["c"],
                                                     edict.pop("points"),
                                                     return_inverse_depth=True)
            alpha = scaled_idepth.amax(dim=(1, 2))
            scaled_idepth = scaled_idepth / alpha[:, None, None]
            edict["t"] = edict["t"] * alpha[:, None]
            quant_d, d_indices = self.encode_to_d(scaled_idepth[:, None, :, :] * 2.0 - 1.0)

        embeddings = self.encode_to_e(**edict)

        if self.merge_channels is None:
            # concat depth and src indices into 2*h*w conditioning indices
            dc_indices = torch.cat((d_indices, c_indices), dim=1)
        else:
            # use empty conditioning indices and compute h*w conditioning
            # embeddings
            dc_indices = torch.zeros_like(d_indices)[:, []]
            merge = torch.cat((quant_d, quant_c), dim=1)
            merge = self.merge_conv(merge)
            merge = merge.permute(0, 2, 3, 1)  # to b,h,w,c
            merge = merge.reshape(merge.shape[0],
                                  merge.shape[1] * merge.shape[2],
                                  merge.shape[3])  # to b,hw,c
            embeddings = torch.cat((embeddings, merge), dim=1)

        # check that unmasking is correct
        total_cond_length = embeddings.shape[1] + dc_indices.shape[1]
        assert total_cond_length == self.transformer.config.n_unmasked, (
            embeddings.shape[1], dc_indices.shape[1], self.transformer.config.n_unmasked)

        return quant_d, quant_c, dc_indices, embeddings

    def forward(self, xdict, cdict, edict, warpkwargs):
        # one step to produce the logits
        _, z_indices = self.encode_to_z(**xdict)
        _, _, dc_indices, embeddings = self.get_normalized_c(cdict, edict)

        cz_indices = torch.cat((dc_indices, z_indices), dim=1)

        # target includes all sequence elements (no need to handle first one
        # differently because we are conditioning)
        target = z_indices

        # make the prediction
        logits, _ = self.transformer(cz_indices[:, :-1],
                                     embeddings=embeddings,
                                     warpkwargs=warpkwargs)

        # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
        logits = logits[:, embeddings.shape[1] + dc_indices.shape[1] - 1:]

        return logits, target

    def top_k_logits(self, logits, k):
        v, ix = torch.topk(logits, k)
        out = logits.clone()
        out[out < v[..., [-1]]] = -float('Inf')
        return out

    @torch.no_grad()
    def sample(self, x, c, steps, temperature=1.0, sample=False, top_k=None,
               callback=lambda k: None, embeddings=None, warpkwargs=None,
               **kwargs):
        # in the current variant we always use embeddings for camera
        assert embeddings is not None
        # check n_unmasked and conditioning length
        total_cond_length = embeddings.shape[1] + c.shape[1]
        assert total_cond_length == self.transformer.config.n_unmasked, (
            embeddings.shape[1], c.shape[1], self.transformer.config.n_unmasked)

        x = torch.cat((c, x), dim=1)
        block_size = self.transformer.get_block_size()
        assert not self.transformer.training

        for k in range(steps):
            callback(k)
            assert x.size(1) <= block_size  # make sure model can see conditioning
            x_cond = x
            logits, _ = self.transformer(x_cond,
                                         embeddings=embeddings,
                                         warpkwargs=warpkwargs)
            # for the next steps, reuse precomputed embeddings for conditioning
            self.transformer.warper.set_cache(True)
            # pluck the logits at the final step and scale by temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop probabilities to only the top k options
            if top_k is not None:
                logits = self.top_k_logits(logits, top_k)
            # apply softmax to convert to probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution or take the most likely
            if sample:
                ix = torch.multinomial(probs, num_samples=1)
            else:
                _, ix = torch.topk(probs, k=1, dim=-1)
            # append to the sequence and continue
            x = torch.cat((x, ix), dim=1)

        # disable caching again
        self.transformer.warper.set_cache(False)
        # cut off conditioning
        x = x[:, c.shape[1]:]
        return x

    @torch.no_grad()
    def decode_to_img(self, index, zshape):
        bhwc = (zshape[0], zshape[2], zshape[3], zshape[1])
        quant_z = self.first_stage_model.quantize.get_codebook_entry(
            index.reshape(-1), shape=bhwc)
        x = self.first_stage_model.decode(quant_z)
        return x

    @torch.no_grad()
    def log_images(self, batch, temperature=None, top_k=None, callback=None,
                   N=4, half_sample=True, sample=True, det_sample=None,
                   **kwargs):
        det_sample = det_sample if det_sample is not None else self.log_det_sample
        log = dict()

        xdict, cdict, edict = self.get_xce(batch, N)
        for k in xdict:
            xdict[k] = xdict[k].to(device=self.device)
        for k in cdict:
            cdict[k] = cdict[k].to(device=self.device)
        for k in edict:
            edict[k] = edict[k].to(device=self.device)
        warpkwargs = self.get_warpkwargs(batch, N)
        for k in warpkwargs:
            warpkwargs[k] = warpkwargs[k].to(device=self.device)

        log["inputs"] = xdict["x"]
        log["conditioning"] = cdict["c"]

        quant_z, z_indices = self.encode_to_z(**xdict)
        quant_d, quant_c, dc_indices, embeddings = self.get_normalized_c(cdict, edict)

        if half_sample:
            # create a "half" sample
            z_start_indices = z_indices[:, :z_indices.shape[1] // 2]
            index_sample = self.sample(z_start_indices, dc_indices,
                                       steps=z_indices.shape[1] - z_start_indices.shape[1],
                                       temperature=temperature if temperature is not None else 1.0,
                                       sample=True,
                                       top_k=top_k if top_k is not None else self.top_k,
                                       callback=callback if callback is not None else lambda k: None,
                                       embeddings=embeddings,
                                       warpkwargs=warpkwargs)
            x_sample = self.decode_to_img(index_sample, quant_z.shape)
            log["samples_half"] = x_sample

        if sample:
            # sample
            z_start_indices = z_indices[:, :0]
            t1 = time.time()
            index_sample = self.sample(z_start_indices, dc_indices,
                                       steps=z_indices.shape[1],
                                       temperature=temperature if temperature is not None else 1.0,
                                       sample=True,
                                       top_k=top_k if top_k is not None else 100,
                                       callback=callback if callback is not None else lambda k: None,
                                       embeddings=embeddings,
                                       warpkwargs=warpkwargs)
            if not hasattr(self, "sampling_time"):
                self.sampling_time = time.time() - t1
                print(f"Full sampling takes about {self.sampling_time:.2f} seconds.")
            x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)
            log["samples_nopix"] = x_sample_nopix

        if det_sample:
            # det sample
            z_start_indices = z_indices[:, :0]
            index_sample = self.sample(z_start_indices, dc_indices,
                                       steps=z_indices.shape[1],
                                       sample=False,
                                       callback=callback if callback is not None else lambda k: None,
                                       embeddings=embeddings,
                                       warpkwargs=warpkwargs)
            x_sample_det = self.decode_to_img(index_sample, quant_z.shape)
            log["samples_det"] = x_sample_det

        # reconstruction
        x_rec = self.decode_to_img(z_indices, quant_z.shape)
        log["reconstructions"] = x_rec

        if self.plot_cond_stage:
            cond_rec = self.cond_stage_model.decode(quant_c)
            log["conditioning_rec"] = cond_rec
            depth_rec = self.depth_stage_model.decode(quant_d)
            log["depth_rec"] = depth_rec

        return log

    def get_input(self, key, batch, heuristics=True):
        x = batch[key]
        if heuristics:
            if len(x.shape) == 3:
                x = x[..., None]
            x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
        if x.dtype == torch.double:
            x = x.float()
        return x

    def get_xce(self, batch, N=None):
        xdict = dict()
        for k, v in self.first_stage_key.items():
            xdict[k] = self.get_input(v, batch, heuristics=k == "x")[:N]

        cdict = dict()
        for k, v in self.cond_stage_key.items():
            cdict[k] = self.get_input(v, batch, heuristics=k == "c")[:N]

        edict = dict()
        for k, v in self.emb_stage_key.items():
            edict[k] = self.get_input(v, batch, heuristics=False)[:N]

        return xdict, cdict, edict

    def get_warpkwargs(self, batch, N=None):
        kwargs = dict()
        for k, v in self.warpkwargs_keys.items():
            kwargs[k] = self.get_input(v, batch, heuristics=k == "x")[:N]
        return kwargs

    def compute_loss(self, logits, targets, split="train"):
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                               targets.reshape(-1))
        return loss, {f"{split}/loss": loss.detach()}

    def shared_step(self, batch, batch_idx):
        x, c, e = self.get_xce(batch)
        warpkwargs = self.get_warpkwargs(batch)
        logits, target = self(x, c, e, warpkwargs)
        return logits, target

    def training_step(self, batch, batch_idx):
        logits, target = self.shared_step(batch, batch_idx)
        loss, log_dict = self.compute_loss(logits, target, split="train")
        self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        logits, target = self.shared_step(batch, batch_idx)
        loss, log_dict = self.compute_loss(logits, target, split="val")
        self.log("val/loss", loss, prog_bar=True, logger=True, on_step=False, on_epoch=True)
        return log_dict

    def configure_optimizers(self):
        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.transformer.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name

                if fpn.startswith("warper._midas"):
                    continue

                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameter in the root GPT module as not decayed
        no_decay.add('pos_emb')
        if hasattr(self.transformer.warper, "pos_emb"):
            no_decay.add('warper.pos_emb')
        # handle meta positional embeddings
        for pn, p in self.transformer.warper.named_parameters():
            if pn.endswith("pos_emb"):
                no_decay.add(f"warper.{pn}")

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.transformer.named_parameters()
                      if not pn.startswith("warper._midas")}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, \
            "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, \
            "parameters %s were not separated into either decay/no_decay set!" % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        extra_parameters = list()
        if self.emb_stage_trainable:
            extra_parameters += list(self.emb_stage_model.parameters())
        if hasattr(self, "merge_conv"):
            extra_parameters += list(self.merge_conv.parameters())
        else:
            assert self.merge_channels is None
        optim_groups.append({"params": extra_parameters, "weight_decay": 0.0})
        print(f"Optimizing {len(extra_parameters)} extra parameters.")

        optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
        if self.use_scheduler:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                }]
            return [optimizer], scheduler
        return optimizer
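# Illustrative bookkeeping only: how the n_unmasked assertions in
# get_normalized_c and sample relate to the conditioning layout. The helper and
# its arguments are made up for this sketch; the two branches mirror the
# merge_channels logic above.
def _expected_n_unmasked(emb_len, h, w, merge_channels=None):
    if merge_channels is None:
        # camera embeddings + depth indices (h*w) + warped-source indices (h*w)
        return emb_len + 2 * h * w
    # with merging: camera embeddings + h*w merged conditioning embeddings,
    # and an empty index sequence
    return emb_len + h * w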
class WarpingFeatureTransformer(pl.LightningModule):
    def __init__(self,
                 transformer_config,
                 first_stage_config,
                 cond_stage_config,
                 ckpt_path=None,
                 ignore_keys=[],
                 first_stage_key="image",
                 cond_stage_key="depth",
                 use_scheduler=False,
                 scheduler_config=None,
                 monitor="val/loss",
                 downsample_cond_size=-1,
                 pkeep=1.0,
                 plot_cond_stage=False,
                 log_det_sample=False,
                 manipulate_latents=False,
                 emb_stage_config=None,
                 emb_stage_key="camera",
                 emb_stage_trainable=True,
                 top_k=None):
        super().__init__()
        if monitor is not None:
            self.monitor = monitor
        self.log_det_sample = log_det_sample
        self.manipulate_latents = manipulate_latents
        self.init_first_stage_from_ckpt(first_stage_config)
        self.init_cond_stage_from_ckpt(cond_stage_config)
        self.transformer = instantiate_from_config(config=transformer_config)

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.first_stage_key = first_stage_key
        self.cond_stage_key = cond_stage_key
        self.downsample_cond_size = downsample_cond_size
        self.pkeep = pkeep
        self.use_scheduler = use_scheduler
        if use_scheduler:
            assert scheduler_config is not None
            self.scheduler_config = scheduler_config
        self.plot_cond_stage = plot_cond_stage

        self.emb_stage_key = emb_stage_key
        self.emb_stage_trainable = emb_stage_trainable and emb_stage_config is not None
        if self.emb_stage_trainable:
            print("### TRAINING EMB STAGE!!!")
        self.init_emb_stage_from_ckpt(emb_stage_config)
        self.top_k = top_k if top_k is not None else 100

        self._midas = Midas()
        self._midas.eval()
        self._midas.train = disabled_train

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        for k in list(sd.keys()):  # iterate over a copy since keys may be deleted
            for ik in ignore_keys:
                if k.startswith(ik):
                    self.print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(
            f"Restored from {path} with {len(missing)} missing keys and {len(unexpected)} unexpected keys."
        )

    def init_first_stage_from_ckpt(self, config):
        model = instantiate_from_config(config)
        self.first_stage_model = model.eval()

    def init_cond_stage_from_ckpt(self, config):
        if config == "__is_first_stage__":
            print("Using first stage also as cond stage.")
            self.cond_stage_model = self.first_stage_model
        else:
            model = instantiate_from_config(config)
            self.cond_stage_model = model.eval()

    def init_emb_stage_from_ckpt(self, config):
        if config is None:
            self.emb_stage_model = None
        else:
            model = instantiate_from_config(config)
            self.emb_stage_model = model
            if not self.emb_stage_trainable:
                self.emb_stage_model.eval()

    def forward(self, xdict, cdict, e=None):
        # one step to produce the logits
        _, z_indices = self.encode_to_z(**xdict)
        _, c_indices = self.encode_to_c(**cdict)
        embeddings = None
        if e is not None:
            embeddings = self.encode_to_e(e)

        if self.training and self.pkeep < 1.0:
            mask = torch.bernoulli(
                self.pkeep * torch.ones(z_indices.shape, device=z_indices.device))
            mask = mask.round().to(dtype=torch.int64)
            r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
            a_indices = mask * z_indices + (1 - mask) * r_indices
        else:
            a_indices = z_indices

        cz_indices = torch.cat((c_indices, a_indices), dim=1)

        # target includes all sequence elements (no need to handle first one
        # differently because we are conditioning)
        target = z_indices

        # make the prediction
        logits, _ = self.transformer(cz_indices[:, :-1], embeddings=embeddings)

        # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
        logits = logits[:, c_indices.shape[1] - 1:]

        return logits, target

    def top_k_logits(self, logits, k):
        v, ix = torch.topk(logits, k)
        out = logits.clone()
        out[out < v[..., [-1]]] = -float('Inf')
        return out

    @torch.no_grad()
    def sample(self, x, c, steps, temperature=1.0, sample=False, top_k=None,
               callback=lambda k: None, embeddings=None, **kwargs):
        x = torch.cat((c, x), dim=1)
        block_size = self.transformer.get_block_size()
        assert not self.transformer.training
        for k in range(steps):
            callback(k)
            assert x.size(1) <= block_size  # make sure model can see conditioning
            x_cond = x if x.size(1) <= block_size else x[:, -block_size:]  # crop context if needed
            logits, _ = self.transformer(x_cond, embeddings=embeddings)
            # pluck the logits at the final step and scale by temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop probabilities to only the top k options
            if top_k is not None:
                logits = self.top_k_logits(logits, top_k)
            # apply softmax to convert to probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution or take the most likely
            if sample:
                ix = torch.multinomial(probs, num_samples=1)
            else:
                _, ix = torch.topk(probs, k=1, dim=-1)
            # append to the sequence and continue
            x = torch.cat((x, ix), dim=1)
        # cut off conditioning
        x = x[:, c.shape[1]:]
        return x

    @torch.no_grad()
    def encode_to_z(self, x):
        quant_z, _, info = self.first_stage_model.encode(x)
        indices = info[2].view(quant_z.shape[0], -1)
        return quant_z, indices

    @torch.no_grad()
    def encode_to_c(self, c, points, R, t, K, K_inv):
        if self.downsample_cond_size > -1:
            assert False, "Rescaling of intrinsics not implemented at this point."
            c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
        # step into
        # quant_c, _, info = self.cond_stage_model.encode(c)
        h = self.cond_stage_model.encoder(c)
        h = self.cond_stage_model.quant_conv(h)
        # now warp
        h, _ = self._midas.warp_features(f=h, x=c, points=points, R=R, t=t,
                                         K_src_inv=K_inv, K_dst=K)
        # continue with quantization
        quant_c, _, info = self.cond_stage_model.quantize(h)
        if quant_c is not None:
            # this is the standard case
            indices = info[2].view(quant_c.shape[0], -1)
        else:
            # e.g. for SlotPretraining.
            indices = info[2]
        return quant_c, indices

    @torch.no_grad()
    def decode_to_img(self, index, zshape):
        bhwc = (zshape[0], zshape[2], zshape[3], zshape[1])
        quant_z = self.first_stage_model.quantize.get_codebook_entry(
            index.reshape(-1), shape=bhwc)
        x = self.first_stage_model.decode(quant_z)
        return x

    @torch.no_grad()
    def log_images(self, batch, temperature=None, top_k=None, callback=None,
                   N=4, half_sample=True, sample=True, det_sample=None,
                   **kwargs):
        det_sample = det_sample if det_sample is not None else self.log_det_sample
        log = dict()

        xdict, cdict = self.get_xc(batch, N)
        for k in xdict:
            xdict[k] = xdict[k].to(device=self.device)
        for k in cdict:
            cdict[k] = cdict[k].to(device=self.device)
        x = xdict["x"]
        c = cdict["c"]

        quant_z, z_indices = self.encode_to_z(**xdict)
        quant_c, c_indices = self.encode_to_c(**cdict)

        embeddings = None
        if self.emb_stage_model is not None and (half_sample or sample or det_sample):
            e = self.get_e(batch, N)
            e = e.to(device=self.device)
            embeddings = self.emb_stage_model(e)

        if half_sample:
            # create a "half" sample
            z_start_indices = z_indices[:, :z_indices.shape[1] // 2]
            index_sample = self.sample(
                z_start_indices, c_indices,
                steps=z_indices.shape[1] - z_start_indices.shape[1],
                temperature=temperature if temperature is not None else 1.0,
                sample=True,
                top_k=top_k if top_k is not None else self.top_k,
                callback=callback if callback is not None else lambda k: None,
                embeddings=embeddings)
            x_sample = self.decode_to_img(index_sample, quant_z.shape)
            log["samples_half"] = x_sample

        if sample:
            # sample
            z_start_indices = z_indices[:, :0]
            t1 = time.time()
            index_sample = self.sample(
                z_start_indices, c_indices,
                steps=z_indices.shape[1],
                temperature=temperature if temperature is not None else 1.0,
                sample=True,
                top_k=top_k if top_k is not None else 100,
                callback=callback if callback is not None else lambda k: None,
                embeddings=embeddings)
            if not hasattr(self, "sampling_time"):
                self.sampling_time = time.time() - t1
                print(
                    f"Full sampling takes about {self.sampling_time:.2f} seconds."
                )
            x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)
            log["samples_nopix"] = x_sample_nopix

        if det_sample:
            # det sample
            z_start_indices = z_indices[:, :0]
            index_sample = self.sample(
                z_start_indices, c_indices,
                steps=z_indices.shape[1],
                sample=False,
                callback=callback if callback is not None else lambda k: None,
                embeddings=embeddings)
            x_sample_det = self.decode_to_img(index_sample, quant_z.shape)
            log["samples_det"] = x_sample_det

        # reconstruction
        x_rec = self.decode_to_img(z_indices, quant_z.shape)

        log["inputs"] = x
        log["reconstructions"] = x_rec

        if self.plot_cond_stage:
            cond_rec = self.cond_stage_model.decode(quant_c)
            if self.cond_stage_key == "segmentation":
                # get image from segmentation mask
                num_classes = cond_rec.shape[1]

                c = torch.argmax(c, dim=1, keepdim=True)
                c = F.one_hot(c, num_classes=num_classes)
                c = c.squeeze(1).permute(0, 3, 1, 2).float()
                c = self.cond_stage_model.to_rgb(c)

                cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
                cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
                cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
                cond_rec = self.cond_stage_model.to_rgb(cond_rec)
            log["conditioning_rec"] = cond_rec
            log["conditioning"] = c

        return log

    def get_input(self, key, batch, heuristics=True):
        x = batch[key]
        if heuristics:
            if key == "caption":
                x = list(x[0])  # coco specific hack
            else:
                if len(x.shape) == 3:
                    x = x[..., None]
                if key not in ["coordinates_bbox"]:
                    x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
        if x.dtype == torch.double:
            x = x.float()
        return x

    def get_xc(self, batch, N=None):
        xdict = dict()
        for k, v in self.first_stage_key.items():
            xdict[k] = self.get_input(v, batch, heuristics=k == "x")[:N]

        cdict = dict()
        for k, v in self.cond_stage_key.items():
            cdict[k] = self.get_input(v, batch, heuristics=k == "c")[:N]

        return xdict, cdict

    def get_e(self, batch, N=None):
        e = self.get_input(self.emb_stage_key, batch)
        if N is not None:
            e = e[:N]
        return e

    def compute_loss(self, logits, targets, split="train"):
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                               targets.reshape(-1))
        return loss, {f"{split}/loss": loss.detach()}

    def shared_step(self, batch, batch_idx):
        x, c = self.get_xc(batch)
        logits, target = self(x, c)
        return logits, target

    def training_step(self, batch, batch_idx):
        logits, target = self.shared_step(batch, batch_idx)
        loss, log_dict = self.compute_loss(logits, target, split="train")
        self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        logits, target = self.shared_step(batch, batch_idx)
        loss, log_dict = self.compute_loss(logits, target, split="val")
        self.log("val/loss", loss, prog_bar=True, logger=True, on_step=False, on_epoch=True)
        return log_dict

    def configure_optimizers(self):
        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.transformer.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name

                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameter in the root GPT module as not decayed
        no_decay.add('pos_emb')

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, \
            "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, \
            "parameters %s were not separated into either decay/no_decay set!" % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {
                "params": [param_dict[pn] for pn in sorted(list(decay))],
                "weight_decay": 0.01
            },
            {
                "params": [param_dict[pn] for pn in sorted(list(no_decay))],
                "weight_decay": 0.0
            },
        ]
        if self.emb_stage_trainable:
            optim_groups.append({
                "params": self.emb_stage_model.parameters(),
                "weight_decay": 0.0
            })
        optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
        if self.use_scheduler:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [{
                'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
                'interval': 'step',
                'frequency': 1
            }]
            return [optimizer], scheduler
        return optimizer
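# Toy illustration (not part of the repo) of the top-k filtering used by both
# sample loops above: everything outside the k largest logits is set to -inf,
# so softmax assigns it zero probability. The values are made up; top_k_logits
# does not use `self`, so it can be called unbound here.
def _demo_top_k_filtering():
    logits = torch.tensor([[2.0, 0.5, 1.5, -1.0, 0.0]])
    filtered = WarpingFeatureTransformer.top_k_logits(None, logits, k=2)
    # filtered == tensor([[2.0, -inf, 1.5, -inf, -inf]])
    probs = F.softmax(filtered, dim=-1)  # only the two kept entries are non-zero
    return filtered, probs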
def __init__(self,
             transformer_config,
             first_stage_config,
             cond_stage_config,
             depth_stage_config,
             merge_channels=None,
             use_depth=True,
             ckpt_path=None,
             ignore_keys=[],
             first_stage_key="image",
             cond_stage_key="depth",
             use_scheduler=False,
             scheduler_config=None,
             monitor="val/loss",
             pkeep=1.0,
             plot_cond_stage=False,
             log_det_sample=False,
             manipulate_latents=False,
             emb_stage_config=None,
             emb_stage_key="camera",
             emb_stage_trainable=True,
             top_p=None,
             top_k=None
             ):
    super().__init__()
    if monitor is not None:
        self.monitor = monitor
    self.log_det_sample = log_det_sample
    self.manipulate_latents = manipulate_latents
    self.init_first_stage_from_ckpt(first_stage_config)
    self.init_cond_stage_from_ckpt(cond_stage_config)
    self.init_depth_stage_from_ckpt(depth_stage_config)
    self.transformer = instantiate_from_config(config=transformer_config)
    self.merge_channels = merge_channels
    if self.merge_channels is not None:
        self.merge_conv = torch.nn.Conv2d(self.merge_channels,
                                          self.transformer.config.n_embd,
                                          kernel_size=1,
                                          padding=0,
                                          bias=False)
    self.use_depth = use_depth
    if not self.use_depth:
        assert self.merge_channels is None

    self.first_stage_key = first_stage_key
    self.cond_stage_key = cond_stage_key
    self.use_scheduler = use_scheduler
    if use_scheduler:
        assert scheduler_config is not None
        self.scheduler_config = scheduler_config
    self.plot_cond_stage = plot_cond_stage

    self.emb_stage_key = emb_stage_key
    self.emb_stage_trainable = emb_stage_trainable and emb_stage_config is not None
    self.init_emb_stage_from_ckpt(emb_stage_config)

    self.top_p = top_p if top_p is not None else 0.95
    try:
        tk = self.first_stage_model.quantize.n_e
    except AttributeError:  # first stage without a quantizer / codebook size
        tk = 100
    self.top_k = top_k if top_k is not None else tk

    if ckpt_path is not None:
        self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    self._midas = Midas()
    self._midas.eval()
    self._midas.train = disabled_train
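# The init above introduces a top_p parameter, but the corresponding filter is
# not part of this excerpt. The sketch below is a common nucleus-sampling
# implementation, not necessarily what the repo's sampler does; it assumes
# logits of shape (batch, vocab) and the usual torch / F imports.
def top_p_logits(logits, p):
    # sort logits so we can accumulate probability mass from most to least likely
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > p
    # shift right so the first token that crosses the threshold is still kept
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False
    # scatter the mask back to the original (unsorted) vocabulary order
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    out = logits.clone()
    out[indices_to_remove] = -float('Inf')
    return out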