示例#1
0
    def initialize_domain_config(self):
        """Build the single-domain configuration and store it on the instance.

        The ``"model_config"``, ``"optimizer_config"`` and ``"loss_config"``
        sections of ``self.model_config`` are combined with
        ``self.domain_name``, ``self.data_loader_dict``, ``self.data_key`` and
        ``self.label_key`` via ``get_domain_configuration``. The result is
        assigned to ``self.domain_config``.

        Sections are looked up with ``[]``, so a missing section raises
        (``KeyError`` if ``self.model_config`` is a plain dict).
        """
        sections = self.model_config
        # Looked up in the same order as before: model, optimizer, loss.
        model_section, optimizer_section, loss_section = (
            sections[key]
            for key in ("model_config", "optimizer_config", "loss_config")
        )

        self.domain_config = get_domain_configuration(
            name=self.domain_name,
            model_dict=model_section,
            data_loader_dict=self.data_loader_dict,
            data_key=self.data_key,
            label_key=self.label_key,
            optimizer_dict=optimizer_section,
            recon_loss_fct_dict=loss_section,
        )
示例#2
0
    def initialize_seq_domain_config(self):
        """Register the ``"rna"`` sequence domain in ``self.domain_configs``.

        If ``self.domain_configs`` is still ``None``, it is first replaced by
        an empty list. The model, optimizer and loss sections of
        ``self.seq_model_config`` are then combined with
        ``self.seq_data_loader_dict``, ``self.seq_data_key`` and
        ``self.seq_label_key`` through ``get_domain_configuration``. The
        resulting configuration is appended to ``self.domain_configs``.
        """
        if self.domain_configs is None:
            self.domain_configs = []

        seq_sections = self.seq_model_config
        seq_model = seq_sections["model_config"]
        seq_optimizer = seq_sections["optimizer_config"]
        seq_loss = seq_sections["loss_config"]

        rna_domain = get_domain_configuration(
            name="rna",
            model_dict=seq_model,
            data_loader_dict=self.seq_data_loader_dict,
            data_key=self.seq_data_key,
            label_key=self.seq_label_key,
            optimizer_dict=seq_optimizer,
            recon_loss_fct_dict=seq_loss,
        )
        self.domain_configs.append(rna_domain)
示例#3
0
    def initialize_image_domain_config(self, train_model: bool = True):
        """Append an ``"image"`` domain configuration to ``self.domain_configs``.

        Args:
            train_model: Forwarded unchanged to ``get_domain_configuration``.

        ``self.domain_configs`` is initialised to an empty list if it is still
        ``None``. The ``"model_config"``, ``"optimizer_config"`` and
        ``"loss_config"`` sections come from ``self.image_model_config``. The
        data loaders, data key and label key come from the matching
        ``image_*`` attributes.
        """
        if self.domain_configs is None:
            self.domain_configs = []

        cfg = self.image_model_config
        model_section, optimizer_section, loss_section = (
            cfg[section]
            for section in ("model_config", "optimizer_config", "loss_config")
        )

        domain_kwargs = {
            "name": "image",
            "model_dict": model_section,
            "data_loader_dict": self.image_data_loader_dict,
            "data_key": self.image_data_key,
            "label_key": self.image_label_key,
            "optimizer_dict": optimizer_section,
            "recon_loss_fct_dict": loss_section,
            "train_model": train_model,
        }
        image_domain = get_domain_configuration(**domain_kwargs)
        self.domain_configs.append(image_domain)
    def initialize_seq_domain_config_2(
        self, name: str = "ATAC", train_model: bool = True
    ):
        """Append the second sequence domain to ``self.domain_configs``.

        Args:
            name: Domain name passed to ``get_domain_configuration``
                (``"ATAC"`` by default).
            train_model: Forwarded unchanged to ``get_domain_configuration``.

        ``self.domain_configs`` is initialised to an empty list when it is
        ``None``. Sections are read from ``self.seq_model_config_2``. The data
        loaders, data key and label key come from ``self.seq_data_loader_dict_2``,
        ``self.seq_data_key_2`` and ``self.seq_label_key_2``.
        """
        if self.domain_configs is None:
            self.domain_configs = []

        cfg = self.seq_model_config_2
        model_section, optimizer_section, loss_section = (
            cfg[section]
            for section in ("model_config", "optimizer_config", "loss_config")
        )

        domain_kwargs = {
            "name": name,
            "model_dict": model_section,
            "data_loader_dict": self.seq_data_loader_dict_2,
            "data_key": self.seq_data_key_2,
            "label_key": self.seq_label_key_2,
            "optimizer_dict": optimizer_section,
            "recon_loss_fct_dict": loss_section,
            "train_model": train_model,
        }
        second_domain = get_domain_configuration(**domain_kwargs)
        self.domain_configs.append(second_domain)
    def initialize_seq_domain_config_1(
        self, name: str = "RNA", train_model: bool = True
    ):
        """Append the first sequence domain and snapshot its initial weights.

        Args:
            name: Domain name passed to ``get_domain_configuration``
                (``"RNA"`` by default).
            train_model: Forwarded unchanged to ``get_domain_configuration``.

        Sections are read from ``self.seq_model_config_1``. The data key and
        label key come from ``self.seq_data_key_1`` and
        ``self.seq_label_key_1``. The new configuration is appended to
        ``self.domain_configs``, which is created as an empty list if it is
        ``None``. A deep copy of its model's ``state_dict()`` is then stored in
        ``self.initial_seq_model_i_weights``.
        """
        if self.domain_configs is None:
            self.domain_configs = []

        cfg = self.seq_model_config_1
        model_section, optimizer_section, loss_section = (
            cfg[section]
            for section in ("model_config", "optimizer_config", "loss_config")
        )

        domain_kwargs = {
            "name": name,
            "model_dict": model_section,
            # NOTE(review): unlike initialize_seq_domain_config_2, no data
            # loader dict is supplied here (always None). Confirm this is
            # intentional, e.g. that loaders are attached later.
            "data_loader_dict": None,
            "data_key": self.seq_data_key_1,
            "label_key": self.seq_label_key_1,
            "optimizer_dict": optimizer_section,
            "recon_loss_fct_dict": loss_section,
            "train_model": train_model,
        }
        first_domain = get_domain_configuration(**domain_kwargs)
        self.domain_configs.append(first_domain)

        # Snapshot the freshly built model's parameters. Assuming `model` is a
        # torch.nn.Module, state_dict() returns references to the live tensors,
        # so deepcopy keeps the snapshot from tracking later updates.
        model_state = first_domain.domain_model_config.model.state_dict()
        self.initial_seq_model_i_weights = copy.deepcopy(model_state)