def initialize_domain_config(self) -> None:
    """Create the single-domain configuration and store it on ``self.domain_config``.

    The ``model_config``, ``optimizer_config`` and ``loss_config`` sections of
    ``self.model_config`` are handed to ``get_domain_configuration`` together
    with ``self.domain_name``, ``self.data_loader_dict``, ``self.data_key`` and
    ``self.label_key``. Any previous value of ``self.domain_config`` is replaced.
    """
    settings = self.model_config
    model_section = settings["model_config"]
    optimizer_section = settings["optimizer_config"]
    loss_section = settings["loss_config"]

    self.domain_config = get_domain_configuration(
        name=self.domain_name,
        model_dict=model_section,
        data_loader_dict=self.data_loader_dict,
        data_key=self.data_key,
        label_key=self.label_key,
        optimizer_dict=optimizer_section,
        recon_loss_fct_dict=loss_section,
    )
def initialize_seq_domain_config(self, name: str = "rna", train_model=None) -> None:
    """Build the sequence-domain configuration and append it to ``self.domain_configs``.

    ``self.domain_configs`` is initialised to an empty list if it is still
    ``None``. The ``model_config``, ``optimizer_config`` and ``loss_config``
    sections of ``self.seq_model_config`` are passed to
    ``get_domain_configuration`` together with ``self.seq_data_loader_dict``,
    ``self.seq_data_key`` and ``self.seq_label_key``.

    Args:
        name: Domain name passed to ``get_domain_configuration``. Defaults to
            ``"rna"``, the value this method previously hard-coded, so existing
            callers are unaffected.
        train_model: When not ``None``, forwarded to
            ``get_domain_configuration`` as ``train_model``. When left as
            ``None`` the argument is omitted entirely, so that function's own
            default applies exactly as before.
    """
    if self.domain_configs is None:
        self.domain_configs = []

    model_config = self.seq_model_config["model_config"]
    optimizer_config = self.seq_model_config["optimizer_config"]
    recon_loss_config = self.seq_model_config["loss_config"]

    # Only pass train_model through when the caller asked for it; the default
    # used by get_domain_configuration is not visible here, so omitting the
    # argument keeps the original call shape for existing callers.
    extra_kwargs = {} if train_model is None else {"train_model": train_model}

    seq_domain_config = get_domain_configuration(
        name=name,
        model_dict=model_config,
        data_loader_dict=self.seq_data_loader_dict,
        data_key=self.seq_data_key,
        label_key=self.seq_label_key,
        optimizer_dict=optimizer_config,
        recon_loss_fct_dict=recon_loss_config,
        **extra_kwargs,
    )
    self.domain_configs.append(seq_domain_config)
def initialize_image_domain_config(self, train_model: bool = True) -> None:
    """Build the image-domain configuration and append it to ``self.domain_configs``.

    ``self.domain_configs`` is created as an empty list first if it is still
    ``None``. The ``model_config``, ``optimizer_config`` and ``loss_config``
    sections of ``self.image_model_config`` are passed to
    ``get_domain_configuration`` under the domain name ``"image"``, along with
    the image data loader dict and data/label keys.

    Args:
        train_model: Forwarded unchanged to ``get_domain_configuration``.
    """
    if self.domain_configs is None:
        self.domain_configs = []

    section = self.image_model_config
    model_cfg, optim_cfg, loss_cfg = (
        section[key] for key in ("model_config", "optimizer_config", "loss_config")
    )

    image_cfg = get_domain_configuration(
        name="image",
        model_dict=model_cfg,
        data_loader_dict=self.image_data_loader_dict,
        data_key=self.image_data_key,
        label_key=self.image_label_key,
        optimizer_dict=optim_cfg,
        recon_loss_fct_dict=loss_cfg,
        train_model=train_model,
    )
    self.domain_configs.append(image_cfg)
def initialize_seq_domain_config_2(self, name: str = "ATAC", train_model: bool = True) -> None:
    """Build the second sequence-domain configuration and append it to ``self.domain_configs``.

    ``self.domain_configs`` is initialised to an empty list when it is still
    ``None``. The model/optimizer/loss sections of ``self.seq_model_config_2``
    are combined with ``self.seq_data_loader_dict_2``, ``self.seq_data_key_2``
    and ``self.seq_label_key_2`` via ``get_domain_configuration``.

    Args:
        name: Domain name passed to ``get_domain_configuration``.
        train_model: Forwarded unchanged to ``get_domain_configuration``.
    """
    if self.domain_configs is None:
        self.domain_configs = []

    cfg_block = self.seq_model_config_2
    model_part = cfg_block["model_config"]
    optimizer_part = cfg_block["optimizer_config"]
    loss_part = cfg_block["loss_config"]

    config_kwargs = dict(
        name=name,
        model_dict=model_part,
        data_loader_dict=self.seq_data_loader_dict_2,
        data_key=self.seq_data_key_2,
        label_key=self.seq_label_key_2,
        optimizer_dict=optimizer_part,
        recon_loss_fct_dict=loss_part,
        train_model=train_model,
    )
    new_domain_config = get_domain_configuration(**config_kwargs)
    self.domain_configs.append(new_domain_config)
def initialize_seq_domain_config_1(self, name: str = "RNA", train_model: bool = True) -> None:
    """Build the first sequence-domain configuration, register it, and snapshot its weights.

    ``self.domain_configs`` is initialised to an empty list when it is still
    ``None``. The model/optimizer/loss sections of ``self.seq_model_config_1``
    are passed to ``get_domain_configuration`` with ``self.seq_data_key_1`` and
    ``self.seq_label_key_1``. The resulting config is appended to
    ``self.domain_configs``, and a deep copy of its model's ``state_dict()`` is
    stored on ``self.initial_seq_model_i_weights`` (presumably so the initial
    weights can be restored later — confirm against callers).

    Args:
        name: Domain name passed to ``get_domain_configuration``.
        train_model: Forwarded unchanged to ``get_domain_configuration``.
    """
    if self.domain_configs is None:
        self.domain_configs = []

    source = self.seq_model_config_1
    model_part, optimizer_part, loss_part = (
        source["model_config"],
        source["optimizer_config"],
        source["loss_config"],
    )

    first_seq_config = get_domain_configuration(
        name=name,
        model_dict=model_part,
        # NOTE(review): unlike initialize_seq_domain_config_2, no data loader
        # dict is passed here — confirm this is intentional.
        data_loader_dict=None,
        data_key=self.seq_data_key_1,
        label_key=self.seq_label_key_1,
        optimizer_dict=optimizer_part,
        recon_loss_fct_dict=loss_part,
        train_model=train_model,
    )
    self.domain_configs.append(first_seq_config)

    # deepcopy so the snapshot does not share storage with the live model's
    # parameters (state_dict() returns references, not copies).
    current_weights = first_seq_config.domain_model_config.model.state_dict()
    self.initial_seq_model_i_weights = copy.deepcopy(current_weights)