示例#1
0
 def __init__(self, ae_config, loss_config, ckpt_path=None, ignore_keys=None):
     """Build the autoencoder and its loss, optionally restoring a checkpoint.

     Args:
         ae_config: Config handed to ``instantiate_from_config`` to create
             ``self.autoencoder``.
         loss_config: Config handed to ``instantiate_from_config`` to create
             ``self.loss``.
         ckpt_path: Optional checkpoint path. When it is not ``None`` the
             path is announced on stdout and passed to
             ``self.init_from_ckpt``.
         ignore_keys: Forwarded unchanged to ``self.init_from_ckpt``.
             ``None`` (the default) is replaced by a fresh empty list.
     """
     super().__init__()
     # Create the default list per call: a literal ``[]`` in the signature
     # is a single object shared by every instance that relies on it.
     if ignore_keys is None:
         ignore_keys = []
     self.autoencoder = instantiate_from_config(ae_config)
     self.loss = instantiate_from_config(loss_config)
     if ckpt_path is not None:
         print("Loading model from {}".format(ckpt_path))
         self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
示例#2
0
def get_data(config):
    """Return the validation dataset of the data object built from ``config.data``.

    The data object is created with ``instantiate_from_config``, then
    ``prepare_data()`` and ``setup()`` are called on it (in that order)
    before its ``datasets["validation"]`` entry is returned.
    """
    datamodule = instantiate_from_config(config.data)
    datamodule.prepare_data()
    datamodule.setup()
    return datamodule.datasets["validation"]
示例#3
0
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    """Instantiate a model from ``config`` and wrap it in a dict.

    Args:
        config: Passed to ``instantiate_from_config`` to build the model.
        sd: State dict given to ``load_state_dict``; skipped when ``None``.
        gpu: When truthy, ``.cuda()`` is called on the model.
        eval_mode: When truthy, ``.eval()`` is called on the model.

    Returns:
        A dict with the single key ``"model"`` holding the model.
    """
    net = instantiate_from_config(config)

    # Weights are optional; a missing state dict leaves the fresh model as-is.
    if sd is not None:
        net.load_state_dict(sd)

    if gpu:
        net.cuda()

    if eval_mode:
        net.eval()

    return dict(model=net)
示例#4
0
 def init_preprocessing(self):
     """Create ``self.dequantizer`` and load its checkpointed weights.

     The dequantizer is instantiated from a fixed target path, its state
     dict is read onto the CPU and loaded non-strictly, and the module is
     switched to eval mode. Its ``train`` attribute is then replaced by
     ``disabled_train`` (presumably so later ``.train()`` calls cannot
     leave eval mode -- confirm against ``disabled_train``).
     """
     target = "net2net.modules.autoencoder.basic.BasicFullyConnectedVAE"
     self.dequantizer = instantiate_from_config({"target": target})

     ckpt_file = get_ckpt_path(
         "dequant_vae",
         "net2net/modules/autoencoder/dequant_vae",
     )
     weights = torch.load(ckpt_file, map_location=torch.device("cpu"))
     # strict=False: key mismatches between checkpoint and module are tolerated.
     self.dequantizer.load_state_dict(weights, strict=False)

     self.dequantizer.eval()
     self.dequantizer.train = disabled_train
示例#5
0
 def __init__(self,
              flow_config,
              first_stage_config,
              cond_stage_config,
              ckpt_path=None,
              ignore_keys=None,
              first_stage_key="image",
              cond_stage_key="image",
              interpolate_cond_size=-1):
     """Set up the first stage, the conditioning stage, the flow and the loss.

     Args:
         flow_config: Config handed to ``instantiate_from_config`` to create
             ``self.flow``.
         first_stage_config: Passed to ``self.init_first_stage_from_ckpt``.
         cond_stage_config: Passed to ``self.init_cond_stage_from_ckpt``.
         ckpt_path: Optional checkpoint path; when not ``None`` it is passed
             to ``self.init_from_ckpt``.
         ignore_keys: Forwarded unchanged to ``self.init_from_ckpt``.
             ``None`` (the default) is replaced by a fresh empty list.
         first_stage_key: Stored as ``self.first_stage_key``.
         cond_stage_key: Stored as ``self.cond_stage_key``.
         interpolate_cond_size: Stored as ``self.interpolate_cond_size``;
             how it is interpreted is decided outside this method.
     """
     super().__init__()
     # Create the default list per call: a literal ``[]`` in the signature
     # is a single object shared by every instance that relies on it.
     if ignore_keys is None:
         ignore_keys = []
     self.init_first_stage_from_ckpt(first_stage_config)
     self.init_cond_stage_from_ckpt(cond_stage_config)
     self.flow = instantiate_from_config(config=flow_config)
     self.loss = NLL()
     self.first_stage_key = first_stage_key
     self.cond_stage_key = cond_stage_key
     self.interpolate_cond_size = interpolate_cond_size
     # Restore the checkpoint last, after all submodules have been created.
     if ckpt_path is not None:
         self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
示例#6
0
 def init_cond_stage_from_ckpt(self, config):
     """Instantiate the conditioning model from ``config`` and keep it in eval mode.

     The result of calling ``.eval()`` on the new model is stored as
     ``self.cond_stage_model``.
     """
     self.cond_stage_model = instantiate_from_config(config).eval()
示例#7
0
 def __init__(self, flow_config):
     """Create the flow from ``flow_config`` and attach an ``NLL`` loss."""
     super().__init__()
     # The flow is assigned before the loss, matching the attribute order
     # callers and saved state may rely on.
     flow = instantiate_from_config(config=flow_config)
     self.flow = flow
     criterion = NLL()
     self.loss = criterion
示例#8
0
 def init_first_stage_from_ckpt(self, config):
     """Instantiate the first-stage model from ``config`` in eval mode.

     The model's ``train`` attribute is overwritten with ``disabled_train``
     (presumably to stop later ``.train()`` calls from leaving eval mode --
     confirm against ``disabled_train``) before it is stored as
     ``self.first_stage_model``.
     """
     first_stage = instantiate_from_config(config).eval()
     first_stage.train = disabled_train
     self.first_stage_model = first_stage
示例#9
0
 def init_to_c_model(self, config):
     """Instantiate the model from ``config`` in eval mode as ``self.to_c_model``.

     Before the model is stored, its ``train`` attribute is overwritten with
     ``disabled_train`` (presumably to stop later ``.train()`` calls from
     leaving eval mode -- confirm against ``disabled_train``).
     """
     to_c = instantiate_from_config(config).eval()
     to_c.train = disabled_train
     self.to_c_model = to_c