示例#1
0
    def test_eval(self, qm9_train_loader, qm9_val_loader, qm9_test_loader,
                  modeldir):
        """Run ``evaluate`` on the test, train and val splits of a small SchNet.

        The model is a SchNet representation (2 interactions, hard cutoff)
        with an ``Atomwise`` output head on ``energy_U0``. After the first
        (test-split) evaluation, the test checks that ``evaluation.txt`` was
        written into ``modeldir``.
        """
        args = Namespace(
            model="schnet",
            cutoff_function="hard",
            features=100,
            n_filters=5,
            interactions=2,
            cutoff=4.0,
            num_gaussians=30,
            modelpath=modeldir,
            split="test",
        )
        # Named `representation` so the builtin `repr` is not shadowed.
        representation = get_representation(args)
        output_module = spk.atomistic.Atomwise(args.features,
                                               property="energy_U0")
        model = get_model(representation, output_module)

        # Evaluate each split in turn; only `args.split` changes between runs.
        for split in ("test", "train", "val"):
            args.split = split
            evaluate(
                args,
                model,
                qm9_train_loader,
                qm9_val_loader,
                qm9_test_loader,
                "cpu",
                # A new metric instance is built for every split.
                metrics=[
                    schnetpack.train.metrics.MeanAbsoluteError(
                        "energy_U0", model_output="energy_U0")
                ],
            )
            if split == "test":
                assert os.path.exists(os.path.join(modeldir,
                                                   "evaluation.txt"))
示例#2
0
 def test_schnet(self):
     """Check ``get_representation`` builds a SchNet for every cutoff function.

     For each of the "hard", "cosine" and "mollifier" cutoff functions, the
     result must be exactly ``spk.SchNet`` (not a subclass). It must have two
     interaction blocks and must not be a ``BehlerSFBlock``.
     """
     for cutoff_function in ("hard", "cosine", "mollifier"):
         args = Namespace(
             model="schnet",
             cutoff_function=cutoff_function,
             features=100,
             n_filters=5,
             interactions=2,
             cutoff=4.0,
             num_gaussians=30,
             normalize_filter=False,
         )
         # Named `representation` so the builtin `repr` is not shadowed.
         representation = get_representation(args)
         # Exact-type checks (not isinstance), matching the original strictness.
         assert type(representation) is spk.SchNet
         assert len(representation.interactions) == 2
         assert type(representation) is not spk.representation.BehlerSFBlock
示例#3
0
def get_model(args, train_loader, mean, stddev, atomref, logging=None):
    """Build the atomistic model used for training.

    The representation is always created with
    ``schnetpack.utils.get_representation``. The output module comes from
    ``schnetpack.utils.get_output_module`` when ``args.dropout == 0`` and
    ``args.n_layers == 2``; otherwise it comes from
    ``get_outmod_with_dropout``.

    Args:
        args: Parsed command-line arguments. This function reads ``mode``,
            ``dropout``, ``n_layers`` and ``parallel``; the helpers it calls
            may read further attributes.
        train_loader: Training data loader, passed to ``get_representation``.
        mean: Passed through to the output-module builder.
        stddev: Passed through to the output-module builder.
        atomref: Passed through to the output-module builder.
        logging: Optional object with an ``info`` method (e.g. the
            ``logging`` module or a logger). Progress messages are skipped
            when it is falsy.

    Returns:
        An ``spk.AtomisticModel`` wrapping the representation and the single
        output module. It is wrapped in ``nn.DataParallel`` when
        ``args.parallel`` is truthy.

    Raises:
        spk.utils.ScriptError: If ``args.mode`` is not ``"train"``.
    """
    if args.mode != "train":
        raise spk.utils.ScriptError("Invalid mode selected: {}".format(
            args.mode))

    if logging:
        logging.info("building model...")

    from schnetpack.utils import get_representation

    # The representation does not depend on the dropout branch.
    representation = get_representation(args, train_loader)

    if args.dropout == 0 and args.n_layers == 2:
        from schnetpack.utils import get_output_module

        build_output_module = get_output_module
    else:
        # Used when dropout != 0 or n_layers != 2; defined outside this
        # function (not in schnetpack.utils).
        build_output_module = get_outmod_with_dropout

    output_module = build_output_module(
        args,
        representation=representation,
        mean=mean,
        stddev=stddev,
        atomref=atomref,
    )
    model = spk.AtomisticModel(representation, [output_module])

    if args.parallel:
        model = nn.DataParallel(model)
    if logging:
        logging.info("The model you built has: %d parameters" %
                     spk.utils.count_params(model))
    return model
示例#4
0
 def test_wacsf(self, qm9_train_loader):
     """Check ``get_representation`` builds a ``BehlerSFBlock`` for "wacsf".

     The result must be exactly ``spk.representation.BehlerSFBlock`` and
     must not be ``spk.SchNet``.
     """
     args = Namespace(
         model="wacsf",
         cutoff_function="cosine",
         features=100,
         n_filters=5,
         interactions=3,
         cutoff=4.0,
         num_gaussians=30,
         behler=False,
         elements=["C"],
         radial=22,
         angular=5,
         zetas=[1],
         centered=True,
         crossterms=True,
         standardize=False,
         cuda=False,
     )
     # Named `representation` so the builtin `repr` is not shadowed.
     representation = get_representation(args, qm9_train_loader)
     # Exact-type checks (not isinstance), matching the original strictness.
     assert type(representation) is not spk.SchNet
     assert type(representation) is spk.representation.BehlerSFBlock