Example #1
 def __init__(self,
              ddconfig,
              lossconfig,
              n_embed,
              embed_dim,
              ckpt_path=None,
              ignore_keys=[],
              image_key="image",
              colorize_nlabels=None,
              monitor=None,
              batch_resize_range=None,
              scheduler_config=None,
              lr_g_factor=1.0,
              ):
     super().__init__()
     self.image_key = image_key
     self.encoder = Encoder(**ddconfig)
     self.decoder = Decoder(**ddconfig)
     self.loss = instantiate_from_config(lossconfig)
     self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)
     self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
     self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
     if colorize_nlabels is not None:
         assert type(colorize_nlabels)==int
         self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
     if monitor is not None:
         self.monitor = monitor
     self.batch_resize_range = batch_resize_range
     if self.batch_resize_range is not None:
         print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
     if ckpt_path is not None:
         self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
     self.scheduler_config = scheduler_config
     self.lr_g_factor = lr_g_factor
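ckpt_path and ignore_keys are consumed by init_from_ckpt, which is not part of this excerpt. A minimal sketch of how such a loader typically works (a hypothetical reconstruction, assuming a Lightning-style checkpoint that stores weights under "state_dict"): load the weights, drop every key that starts with one of ignore_keys, and load the remainder non-strictly.

    import torch

    def init_from_ckpt(self, path, ignore_keys=list()):
        # Hypothetical sketch: assumes the checkpoint keeps the weights under "state_dict".
        sd = torch.load(path, map_location="cpu")["state_dict"]
        for k in list(sd.keys()):
            if any(k.startswith(ik) for ik in ignore_keys):
                print(f"Deleting key {k} from state_dict.")
                del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")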
Example #2
 def init_cond_stage_from_ckpt(self, config):
     if config == "__is_first_stage__":
         print("Using first stage also as cond stage.")
         self.cond_stage_model = self.first_stage_model
     else:
         model = instantiate_from_config(config)
         self.cond_stage_model = model.eval()
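instantiate_from_config is used throughout these excerpts but never shown. Assuming it follows the common target/params config convention (an assumption; the project's utility may differ in detail), a minimal sketch is:

    import importlib

    def instantiate_from_config(config):
        # Hypothetical sketch: "target" names the class to import,
        # "params" holds its constructor keyword arguments.
        module_name, cls_name = config["target"].rsplit(".", 1)
        cls = getattr(importlib.import_module(module_name), cls_name)
        return cls(**config.get("params", dict()))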
Example #3
    def configure_optimizers(self):
        lr_d = self.learning_rate
        lr_g = self.lr_g_factor*self.learning_rate
        print("lr_d", lr_d)
        print("lr_g", lr_g)
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quantize.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr_g, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr_d, betas=(0.5, 0.9))

        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
                {
                    'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
            ]
            return [opt_ae, opt_disc], scheduler
        return [opt_ae, opt_disc], []
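Returning [opt_ae, opt_disc] together with a list of scheduler dicts follows PyTorch Lightning's configure_optimizers convention of (optimizers, schedulers). The object produced by instantiate_from_config(self.scheduler_config) only needs to expose a schedule(step) method returning a multiplicative learning-rate factor, since the interval is set to 'step'. A hypothetical schedule with that interface (the class below is an illustrative assumption, not the project's actual scheduler):

    class LambdaWarmUpSchedule:
        # Illustrative schedule object: linear warm-up to a factor of 1.0, then constant.
        def __init__(self, warm_up_steps=10000):
            self.warm_up_steps = warm_up_steps

        def schedule(self, step):
            if step < self.warm_up_steps:
                return (step + 1) / self.warm_up_steps
            return 1.0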
Example #4
 def init_emb_stage_from_ckpt(self, config):
     if config is None:
         self.emb_stage_model = None
     else:
         model = instantiate_from_config(config)
         self.emb_stage_model = model
         if not self.emb_stage_trainable:
             self.emb_stage_model.eval()
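Note that calling eval() only switches layers such as dropout and batch norm into inference mode; it does not freeze weights. The non-trainable embedding stage stays fixed here only because Example #5 never hands its parameters to the optimizer. If a hard freeze were wanted, a helper along these lines is the usual approach (hypothetical, not part of these excerpts):

    import torch.nn as nn

    def freeze(module: nn.Module):
        # Hypothetical helper: clear requires_grad in addition to eval(),
        # so the parameters cannot be updated even if passed to an optimizer.
        module.eval()
        for p in module.parameters():
            p.requires_grad = False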
Example #5
    def configure_optimizers(self):
        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.transformer.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name

                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameter in the root GPT module as not decayed
        no_decay.add('pos_emb')

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                    % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        extra_parameters = list()
        if self.emb_stage_trainable:
            extra_parameters += list(self.emb_stage_model.parameters())
        if hasattr(self, "merge_conv"):
            extra_parameters += list(self.merge_conv.parameters())
        else:
            assert self.merge_channels is None
        optim_groups.append({"params": extra_parameters, "weight_decay": 0.0})
        print(f"Optimizing {len(extra_parameters)} extra parameters.")
        optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
        if self.use_scheduler:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                }]
            return [optimizer], scheduler
        return optimizer
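The decay/no-decay split above is the standard minGPT-style grouping: Linear weights receive weight decay, while biases and LayerNorm/Embedding weights do not (plus the pos_emb special case for the root module). The same idea on a small self-contained toy module (the module and hyperparameters are illustrative only):

    import torch
    import torch.nn as nn

    class TinyBlock(nn.Module):
        # Toy module mixing decayed (Linear weight) and non-decayed (LayerNorm, biases) parameters.
        def __init__(self):
            super().__init__()
            self.ln = nn.LayerNorm(16)
            self.fc = nn.Linear(16, 16)

    model = TinyBlock()
    decay, no_decay = set(), set()
    for mn, m in model.named_modules():
        for pn, _ in m.named_parameters():
            fpn = f"{mn}.{pn}" if mn else pn  # full parameter name
            if pn.endswith("bias") or (pn.endswith("weight") and isinstance(m, (nn.LayerNorm, nn.Embedding))):
                no_decay.add(fpn)
            elif pn.endswith("weight") and isinstance(m, nn.Linear):
                decay.add(fpn)

    params = dict(model.named_parameters())
    optimizer = torch.optim.AdamW(
        [{"params": [params[n] for n in sorted(decay)], "weight_decay": 0.01},
         {"params": [params[n] for n in sorted(no_decay)], "weight_decay": 0.0}],
        lr=3e-4, betas=(0.9, 0.95))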
Example #6
    def __init__(
        self,
        transformer_config,
        first_stage_config,
        cond_stage_config,
        ckpt_path=None,
        ignore_keys=[],
        first_stage_key="image",
        cond_stage_key="depth",
        use_scheduler=False,
        scheduler_config=None,
        monitor="val/loss",
        downsample_cond_size=-1,
        pkeep=1.0,
        plot_cond_stage=False,
        log_det_sample=False,
        manipulate_latents=False,
        emb_stage_config=None,
        emb_stage_key="camera",
        emb_stage_trainable=True,
        top_k=None,
        unconditional=False,
        sos_token=0,
        use_first_stage_get_input=False,
    ):

        super().__init__()
        self.use_first_stage_get_input = use_first_stage_get_input
        if monitor is not None:
            self.monitor = monitor
        self.be_unconditional = unconditional
        self.sos_token = sos_token
        self.first_stage_key = first_stage_key
        self.log_det_sample = log_det_sample
        self.manipulate_latents = manipulate_latents
        self.init_first_stage_from_ckpt(first_stage_config)
        self.init_cond_stage_from_ckpt(cond_stage_config)
        self.transformer = instantiate_from_config(config=transformer_config)

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        if not hasattr(self, "cond_stage_key"):
            self.cond_stage_key = cond_stage_key
        self.downsample_cond_size = downsample_cond_size
        self.pkeep = pkeep

        self.use_scheduler = use_scheduler
        if use_scheduler:
            assert scheduler_config is not None
            self.scheduler_config = scheduler_config
        self.plot_cond_stage = plot_cond_stage
        self.emb_stage_key = emb_stage_key
        self.emb_stage_trainable = emb_stage_trainable and emb_stage_config is not None
        if self.emb_stage_trainable:
            print("### TRAINING EMB STAGE!!!")
        self.init_emb_stage_from_ckpt(emb_stage_config)
        self.top_k = top_k if top_k is not None else 100
Example #7
 def __init__(self, **kwargs):
     assert "n_unmasked" in kwargs and kwargs["n_unmasked"] != 0
     warper_config = kwargs.pop("warper_config")
     if warper_config.params is None:
         warper_config.params = dict()
     warper_config.params["n_unmasked"] = kwargs["n_unmasked"]
     warper_config.params["block_size"] = kwargs["block_size"]
     warper_config.params["n_embd"] = kwargs["n_embd"]
     super().__init__(**kwargs)
     self.warper = instantiate_from_config(warper_config)
Example #8
 def init_cond_stage_from_ckpt(self, config):
     if config == "__is_first_stage__":
         print("Using first stage also as cond stage.")
         self.cond_stage_model = self.first_stage_model
     elif config == "__is_unconditional__" or self.be_unconditional:
         print(
             f"Using no cond stage. Assuming the training is intended to be unconditional. "
             f"Prepending {self.sos_token} as a sos token.")
         self.be_unconditional = True
         self.cond_stage_key = self.first_stage_key
         self.cond_stage_model = SOSProvider(self.sos_token)
     else:
         model = instantiate_from_config(config)
         self.cond_stage_model = model.eval()
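SOSProvider replaces the conditioning encoder in the unconditional case; it is assumed to do nothing but emit the configured start-of-sequence token for every batch element. A simplified sketch under that assumption (the exact return signature of the real class may differ):

    import torch
    import torch.nn as nn

    class SOSProvider(nn.Module):
        # Hypothetical stand-in cond stage: every sample is conditioned on the same sos token.
        def __init__(self, sos_token):
            super().__init__()
            self.sos_token = sos_token

        def encode(self, x):
            # one sos index per batch element, on the same device as the input
            return torch.full((x.shape[0], 1), self.sos_token, dtype=torch.long, device=x.device)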
Example #9
    def __init__(self,
                 transformer_config,
                 first_stage_config,
                 cond_stage_config,
                 ckpt_path=None,
                 ignore_keys=[],
                 first_stage_key="dst_img",
                 cond_stage_key="src_img",
                 use_scheduler=False,
                 scheduler_config=None,
                 monitor="val/loss",
                 pkeep=1.0,
                 plot_cond_stage=False,
                 log_det_sample=False,
                 top_k=None):

        super().__init__()
        if monitor is not None:
            self.monitor = monitor
        self.log_det_sample = log_det_sample
        self.init_first_stage_from_ckpt(first_stage_config)
        self.init_cond_stage_from_ckpt(cond_stage_config)
        self.transformer = instantiate_from_config(config=transformer_config)

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.first_stage_key = first_stage_key
        self.cond_stage_key = cond_stage_key
        self.pkeep = pkeep

        self.use_scheduler = use_scheduler
        if use_scheduler:
            assert scheduler_config is not None
            self.scheduler_config = scheduler_config
        self.plot_cond_stage = plot_cond_stage
        self.top_k = top_k if top_k is not None else 100
        self.warpkwargs_keys = {
            "x": "src_img",
            "points": "src_points",
            "R": "R_rel",
            "t": "t_rel",
            "K_dst": "K",
            "K_src_inv": "K_inv",
        }
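warpkwargs_keys maps argument names of a warping call to keys in the data batch. It is presumably consumed roughly as in the sketch below when assembling the warp inputs (the helper name and the surrounding call are illustrative assumptions):

    def gather_warp_kwargs(batch, key_map):
        # Illustrative helper: select and rename batch entries according to warpkwargs_keys,
        # e.g. {"x": "src_img", ...} becomes {"x": batch["src_img"], ...} for a warp(**kwargs) call.
        return {arg: batch[key] for arg, key in key_map.items()}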
Example #10
 def init_depth_stage_from_ckpt(self, config):
     model = instantiate_from_config(config)
     self.depth_stage_model = model.eval()
     self.depth_stage_model.train = disabled_train
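disabled_train is not shown; overriding the bound train method like this is a common trick to keep a frozen sub-model in eval mode even when the surrounding module later calls .train() on all children. A minimal sketch under that assumption:

    def disabled_train(self, mode=True):
        # Replacement for Module.train that ignores the requested mode,
        # so the frozen model stays in eval mode permanently.
        return self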
Example #11
    def __init__(self,
                 transformer_config,
                 first_stage_config,
                 cond_stage_config,
                 depth_stage_config,
                 merge_channels=None,
                 ckpt_path=None,
                 ignore_keys=[],
                 first_stage_key="image",
                 cond_stage_key="depth",
                 use_scheduler=False,
                 scheduler_config=None,
                 monitor="val/loss",
                 plot_cond_stage=False,
                 log_det_sample=False,
                 manipulate_latents=False,
                 emb_stage_config=None,
                 emb_stage_key="camera",
                 emb_stage_trainable=True,
                 top_k=None
                 ):

        super().__init__()
        if monitor is not None:
            self.monitor = monitor
        self.log_det_sample = log_det_sample
        self.manipulate_latents = manipulate_latents
        self.init_first_stage_from_ckpt(first_stage_config)
        self.init_cond_stage_from_ckpt(cond_stage_config)
        self.init_depth_stage_from_ckpt(depth_stage_config)
        self.transformer = instantiate_from_config(config=transformer_config)

        self.merge_channels = merge_channels
        if self.merge_channels is not None:
            self.merge_conv = torch.nn.Conv2d(self.merge_channels,
                                              self.transformer.config.n_embd,
                                              kernel_size=1,
                                              padding=0,
                                              bias=False)

        self.first_stage_key = first_stage_key
        self.cond_stage_key = cond_stage_key

        self.use_scheduler = use_scheduler
        if use_scheduler:
            assert scheduler_config is not None
            self.scheduler_config = scheduler_config
        self.plot_cond_stage = plot_cond_stage
        self.emb_stage_key = emb_stage_key
        self.emb_stage_trainable = emb_stage_trainable and emb_stage_config is not None
        if self.emb_stage_trainable:
            print("### TRAINING EMB STAGE!!!")
        self.init_emb_stage_from_ckpt(emb_stage_config)
        self.top_k = top_k if top_k is not None else 100

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

        self._midas = Midas()
        self._midas.eval()
        self._midas.train = disabled_train

        self.warpkwargs_keys = {
            "x": "src_img",
            "points": "src_points",
            "R": "R_rel",
            "t": "t_rel",
            "K_dst": "K",
            "K_src_inv": "K_inv",
        }
Example #12
 def init_first_stage_from_ckpt(self, config):
     model = instantiate_from_config(config)
     self.first_stage_model = model.eval()
Example #13
    def __init__(self,
                 transformer_config,
                 first_stage_config,
                 cond_stage_config,
                 depth_stage_config,
                 merge_channels=None,
                 use_depth=True,
                 ckpt_path=None,
                 ignore_keys=[],
                 first_stage_key="image",
                 cond_stage_key="depth",
                 use_scheduler=False,
                 scheduler_config=None,
                 monitor="val/loss",
                 pkeep=1.0,
                 plot_cond_stage=False,
                 log_det_sample=False,
                 manipulate_latents=False,
                 emb_stage_config=None,
                 emb_stage_key="camera",
                 emb_stage_trainable=True,
                 top_p=None,
                 top_k=None
                 ):

        super().__init__()
        if monitor is not None:
            self.monitor = monitor
        self.log_det_sample = log_det_sample
        self.manipulate_latents = manipulate_latents
        self.init_first_stage_from_ckpt(first_stage_config)
        self.init_cond_stage_from_ckpt(cond_stage_config)
        self.init_depth_stage_from_ckpt(depth_stage_config)
        self.transformer = instantiate_from_config(config=transformer_config)

        self.merge_channels = merge_channels
        if self.merge_channels is not None:
            self.merge_conv = torch.nn.Conv2d(self.merge_channels,
                                              self.transformer.config.n_embd,
                                              kernel_size=1,
                                              padding=0,
                                              bias=False)

        self.use_depth = use_depth
        if not self.use_depth:
            assert self.merge_channels is None

        self.first_stage_key = first_stage_key
        self.cond_stage_key = cond_stage_key

        self.use_scheduler = use_scheduler
        if use_scheduler:
            assert scheduler_config is not None
            self.scheduler_config = scheduler_config
        self.plot_cond_stage = plot_cond_stage
        self.emb_stage_key = emb_stage_key
        self.emb_stage_trainable = emb_stage_trainable and emb_stage_config is not None
        self.init_emb_stage_from_ckpt(emb_stage_config)
        self.top_p = top_p if top_p is not None else 0.95
        try:
            tk = self.first_stage_model.quantize.n_e
        except AttributeError:
            # fall back when the first stage exposes no codebook size
            tk = 100
        self.top_k = top_k if top_k is not None else tk

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

        self._midas = Midas()
        self._midas.eval()
        self._midas.train = disabled_train
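top_k defaults to the first-stage codebook size (quantize.n_e) when it is available, which effectively disables truncation, and top_p is presumably a nucleus-sampling threshold applied at sampling time. A top-k cutoff on the transformer logits is typically implemented as in the minGPT-style sketch below (a standard technique, not necessarily this project's exact code):

    import torch

    def top_k_logits(logits, k):
        # Keep the k largest logits per row and mask the rest, so that
        # sampling from the softmax can only pick one of the top-k tokens.
        v, _ = torch.topk(logits, k)
        out = logits.clone()
        out[out < v[..., [-1]]] = -float("inf")
        return out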