コード例 #1
0
 def pretrained(self, model_dir):
     """Delegate to ``load_pretrained`` for the ``"hiv"`` MPNN variant.

     The model name is built as ``"mpnn_<aggr>"`` from ``self.aggr``;
     ``self.hidden``, ``model_dir`` and ``PRETRAINED_CONF`` are forwarded
     unchanged, and whatever ``load_pretrained`` returns is returned.
     """
     options = {
         "dataset_name": "hiv",
         "model_name": "mpnn_{}".format(self.aggr),
         "hidden": self.hidden,
         "model_dir": model_dir,
         "pretrained_conf": PRETRAINED_CONF,
     }
     return load_pretrained(self, **options)
コード例 #2
0
ファイル: configs.py プロジェクト: jingmouren/egc
 def pretrained(self, model_dir):
     """Delegate to ``load_pretrained`` for the PNA model on ``"arxiv"``.

     ``self`` is passed as the first argument, together with
     ``self.hidden``, ``model_dir`` and ``PRETRAINED_CONF``; the result of
     ``load_pretrained`` is returned as-is.
     """
     dataset, model = "arxiv", "pna"
     width = self.hidden
     return load_pretrained(
         self,
         dataset_name=dataset,
         model_name=model,
         hidden=width,
         model_dir=model_dir,
         pretrained_conf=PRETRAINED_CONF,
     )
コード例 #3
0
 def pretrained(self, model_dir):
     """Delegate to ``load_pretrained`` for the PNA model on ``"code2"``.

     Requires ``self.use_old_code_dataset`` to be falsy; otherwise the
     ``assert`` fails (only when assertions are enabled) before anything
     is loaded. ``self.hidden``, ``model_dir`` and ``PRETRAINED_CONF`` are
     forwarded to ``load_pretrained``.
     """
     assert not self.use_old_code_dataset
     options = {
         "dataset_name": "code2",
         "model_name": "pna",
         "hidden": self.hidden,
         "model_dir": model_dir,
         "pretrained_conf": PRETRAINED_CONF,
     }
     return load_pretrained(self, **options)
コード例 #4
0
    def pretrained(self, model_dir):
        """Delegate to ``load_pretrained`` for the matching EGC model on ``"hiv"``.

        Only two configurations have a pretrained model:

        * ``"egc_s"``: exactly one aggregator, ``"symadd"``, with
          ``hidden == 236``, ``num_heads == 4`` and ``num_bases == 4``.
        * ``"egc_m"``: three aggregators covering ``"add"``, ``"max"`` and
          ``"mean"``, with ``hidden == 224``, ``num_heads == 4`` and
          ``num_bases == 4``.

        ``self.softmax`` must be falsy for both.

        Raises:
            AssertionError: if ``self.softmax`` is set or the hyper-parameters
                do not match the selected variant (only when assertions are
                enabled, i.e. not under ``python -O``).
            ValueError: if ``self.aggrs`` has neither 1 nor 3 entries.
        """
        assert not self.softmax, "no pretrained EGC model uses softmax"
        if len(self.aggrs) == 1:
            assert "symadd" in self.aggrs, (
                f"egc_s expects aggrs ['symadd'], got {self.aggrs!r}"
            )
            # Keep the short-circuiting `and` chain so attribute access order
            # is unchanged.
            assert self.hidden == 236 and self.num_heads == 4 and self.num_bases == 4, (
                "egc_s expects hidden=236, num_heads=4, num_bases=4; got "
                f"hidden={self.hidden}, num_heads={self.num_heads}, "
                f"num_bases={self.num_bases}"
            )
            model = "egc_s"
        elif len(self.aggrs) == 3:
            assert set(self.aggrs).issuperset({"add", "max", "mean"}), (
                f"egc_m expects aggrs {{'add', 'max', 'mean'}}, got {self.aggrs!r}"
            )
            assert self.hidden == 224 and self.num_heads == 4 and self.num_bases == 4, (
                "egc_m expects hidden=224, num_heads=4, num_bases=4; got "
                f"hidden={self.hidden}, num_heads={self.num_heads}, "
                f"num_bases={self.num_bases}"
            )
            model = "egc_m"
        else:
            raise ValueError(
                f"no pretrained EGC model for aggrs={self.aggrs!r}; expected 1 "
                "aggregator (egc_s) or 3 aggregators (egc_m)"
            )

        return load_pretrained(
            self,
            dataset_name="hiv",
            model_name=model,
            hidden=self.hidden,
            model_dir=model_dir,
            pretrained_conf=PRETRAINED_CONF,
        )