Ejemplo n.º 1
0
def _code(
    model,
    num_heads,
    num_bases,
    aggrs,
    hidden,
    use_old_code_dataset,
):
    """Prepare the code dataset and build the config for ``model``.

    ``code_data`` is called before the model name is inspected, so an
    unsupported ``model`` still triggers that call before ``ValueError``
    is raised.

    Args:
        model: One of ``"egc"``, ``"egc-softmax"``, ``"gcn"``, ``"gat"``,
            ``"gin"``, ``"mpnn-sum"``, ``"mpnn-max"`` or ``"pna"``.
        num_heads: Forwarded to ``EgcCodeConfig`` only.
        num_bases: Forwarded to ``EgcCodeConfig`` only.
        aggrs: Forwarded to ``EgcCodeConfig`` only.
        hidden: Forwarded to every config class.
        use_old_code_dataset: Forwarded to ``code_data`` and to every
            config class.

    Returns:
        The config object constructed for the requested model.

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    # Return value is discarded; presumably this call prepares/caches the
    # dataset on disk -- TODO confirm against code_data.
    code_data(data_location(), use_old_code_dataset=use_old_code_dataset)

    if model in ("egc", "egc-softmax"):
        return EgcCodeConfig(
            hidden=hidden,
            num_bases=num_bases,
            num_heads=num_heads,
            softmax="softmax" in model,
            aggrs=aggrs,
            use_old_code_dataset=use_old_code_dataset,
        )
    if model == "gcn":
        return GcnCodeConfig(hidden, use_old_code_dataset=use_old_code_dataset)
    if model == "gat":
        return GatCodeConfig(hidden, use_old_code_dataset=use_old_code_dataset)
    if model == "gin":
        return GinCodeConfig(hidden, use_old_code_dataset=use_old_code_dataset)
    if model in ("mpnn-sum", "mpnn-max"):
        # "mpnn-sum" selects the "add" aggregator, "mpnn-max" selects "max".
        return MpnnCodeConfig(
            hidden,
            aggr="add" if "sum" in model else "max",
            use_old_code_dataset=use_old_code_dataset,
        )
    if model == "pna":
        return PnaCodeConfig(hidden, use_old_code_dataset=use_old_code_dataset)
    raise ValueError(f"Unsupported model for the code dataset: {model!r}")
Ejemplo n.º 2
0
def _zinc(model, num_samples, egc_num_bases, egc_num_heads, aggrs, hidden):
    """Prepare the ZINC dataset and build the EGC config for ``model``.

    Only ``"egc"`` and ``"egc-softmax"`` are supported. ``zinc_data`` is
    called before the model name is checked, so an unsupported ``model``
    still triggers that call before ``ValueError`` is raised.

    Args:
        model: ``"egc"`` or ``"egc-softmax"``; the latter sets
            ``softmax=True`` on the config.
        num_samples: Forwarded as ``num_samples``.
        egc_num_bases: Forwarded as ``num_bases``.
        egc_num_heads: Forwarded as ``num_heads``.
        aggrs: Forwarded as ``aggrs``.
        hidden: Forwarded as ``hidden``.

    Returns:
        A ``ZincEgcConfig`` instance.

    Raises:
        ValueError: If ``model`` is not ``"egc"`` or ``"egc-softmax"``.
    """
    # Return value is discarded; presumably this call prepares/caches the
    # dataset on disk -- TODO confirm against zinc_data.
    zinc_data(data_location())

    if model not in ("egc", "egc-softmax"):
        raise ValueError(
            f"Unsupported model for the ZINC dataset: {model!r} "
            "(expected 'egc' or 'egc-softmax')"
        )

    return ZincEgcConfig(
        num_samples=num_samples,
        softmax="softmax" in model,
        num_bases=egc_num_bases,
        num_heads=egc_num_heads,
        aggrs=aggrs,
        hidden=hidden,
    )
Ejemplo n.º 3
0
def _arxiv(model, num_heads, num_bases, aggrs, hidden):
    """Prepare the arxiv dataset and build the config for ``model``.

    ``arxiv_data`` is called before the model name is inspected, so an
    unsupported ``model`` still triggers that call before ``ValueError``
    is raised.

    Args:
        model: One of ``"gcn"``, ``"gat"``, ``"gin"``, ``"egc"``,
            ``"egc-softmax"``, ``"mpnn-sum"``, ``"mpnn-max"`` or ``"pna"``.
        num_heads: Forwarded to ``EgcArxivConfig`` only.
        num_bases: Forwarded to ``EgcArxivConfig`` only.
        aggrs: Forwarded to ``EgcArxivConfig`` only.
        hidden: Forwarded to every config class.

    Returns:
        The config object constructed for the requested model.

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    # Return value is discarded; presumably this call prepares/caches the
    # dataset on disk -- TODO confirm against arxiv_data.
    arxiv_data(data_location())

    if model == "gcn":
        return GcnArxivConfig(hidden)
    if model == "gat":
        return GatArxivConfig(hidden)
    if model == "gin":
        return GinArxivConfig(hidden)
    if model in ("egc", "egc-softmax"):
        return EgcArxivConfig(
            num_heads=num_heads,
            num_bases=num_bases,
            softmax="softmax" in model,
            aggrs=aggrs,
            hidden=hidden,
        )
    if model in ("mpnn-sum", "mpnn-max"):
        # "mpnn-sum" selects the "add" aggregator, "mpnn-max" selects "max".
        return MpnnArxivConfig(hidden, aggr="add" if "sum" in model else "max")
    if model == "pna":
        return PnaArxivConfig(hidden)
    raise ValueError(f"Unsupported model for the arxiv dataset: {model!r}")
Ejemplo n.º 4
0
def _mol(model, dataset, num_heads, num_bases, aggrs, hidden):
    """Prepare the molecular dataset ``dataset`` and build the config for ``model``.

    ``mol_data`` is called before the model name is inspected, so an
    unsupported ``model`` still triggers that call before ``ValueError``
    is raised. Note that ``"pna"`` is not handled here and is rejected.

    Args:
        model: One of ``"egc"``, ``"egc-softmax"``, ``"gcn"``, ``"gat"``,
            ``"gin"``, ``"mpnn-sum"`` or ``"mpnn-max"``.
        dataset: Passed to ``mol_data`` and as the first argument of every
            config class.
        num_heads: Forwarded to ``EgcMolConfig`` only.
        num_bases: Forwarded to ``EgcMolConfig`` only.
        aggrs: Forwarded to ``EgcMolConfig`` only.
        hidden: Forwarded to every config class.

    Returns:
        The config object constructed for the requested model.

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    # Return value is discarded; presumably this call prepares/caches the
    # dataset on disk -- TODO confirm against mol_data.
    mol_data(data_location(), dataset)

    if model in ("egc", "egc-softmax"):
        return EgcMolConfig(
            dataset,
            hidden=hidden,
            num_bases=num_bases,
            num_heads=num_heads,
            softmax="softmax" in model,
            aggrs=aggrs,
        )
    if model == "gcn":
        return GcnMolConfig(dataset, hidden)
    if model == "gat":
        return GatMolConfig(dataset, hidden)
    if model == "gin":
        return GinMolConfig(dataset, hidden)
    if model in ("mpnn-sum", "mpnn-max"):
        # "mpnn-sum" selects the "add" aggregator, "mpnn-max" selects "max".
        return MpnnMolConfig(
            dataset,
            hidden,
            aggr="add" if "sum" in model else "max",
        )
    raise ValueError(
        f"Unsupported model for molecular dataset {dataset!r}: {model!r}"
    )
Ejemplo n.º 5
0
 def data(self, pinned_objs, hparams):
     """Return the data produced by ``mol_data`` for ``self.dataset``.

     Args:
         pinned_objs: Not used by this method.
         hparams: Mapping whose ``"batch_size"`` entry sets the batch size.

     Returns:
         The result of ``mol_data`` for this dataset and batch size.
     """
     root = data_location()
     return mol_data(root, dataset=self.dataset, batch_size=hparams["batch_size"])
Ejemplo n.º 6
0
 def data(self, pinned_objs, hparams):
     """Return the data produced by ``cifar_data``.

     Args:
         pinned_objs: Not used by this method.
         hparams: Mapping whose ``"batch_size"`` entry sets the batch size.

     Returns:
         The result of ``cifar_data`` for the configured batch size.
     """
     root = data_location()
     size = hparams["batch_size"]
     return cifar_data(root, batch_size=size)
Ejemplo n.º 7
0
 def data(self, pinned_objs, hparams):
     """Return the data produced by ``arxiv_data``.

     Neither ``pinned_objs`` nor ``hparams`` is consulted; ``arxiv_data``
     is called with only the data location.

     Returns:
         The result of ``arxiv_data``.
     """
     root = data_location()
     return arxiv_data(root)
Ejemplo n.º 8
0
 def data(self, pinned_objs, hparams):
     """Return the data produced by ``code_data``.

     ``hparams`` is not consulted. The batch size is fixed at 128.
     NOTE(review): confirm that ignoring ``hparams["batch_size"]`` here is
     intentional.

     Args:
         pinned_objs: Not used by this method.
         hparams: Not used by this method.

     Returns:
         The result of ``code_data``, using ``self.use_old_code_dataset``.
     """
     root = data_location()
     return code_data(
         root,
         batch_size=128,
         use_old_code_dataset=self.use_old_code_dataset,
     )