Example #1
    def test_eval(self, qm9_train_loader, qm9_val_loader, qm9_test_loader,
                  modeldir):
        args = Namespace(
            model="schnet",
            cutoff_function="hard",
            features=100,
            n_filters=5,
            interactions=2,
            cutoff=4.0,
            num_gaussians=30,
            modelpath=modeldir,
            split="test",
        )
        # build the SchNet representation and an Atomwise output module,
        # then assemble them into the model under test
        representation = get_representation(args)
        output_module = spk.atomistic.Atomwise(args.features,
                                               property="energy_U0")
        model = get_model(representation, output_module)

        evaluate(
            args,
            model,
            qm9_train_loader,
            qm9_val_loader,
            qm9_test_loader,
            "cpu",
            metrics=[
                schnetpack.train.metrics.MeanAbsoluteError(
                    "energy_U0", model_output="energy_U0")
            ],
        )
        assert os.path.exists(os.path.join(modeldir, "evaluation.txt"))

        # repeat the evaluation for the train and validation splits
        args.split = "train"
        evaluate(
            args,
            model,
            qm9_train_loader,
            qm9_val_loader,
            qm9_test_loader,
            "cpu",
            metrics=[
                schnetpack.train.metrics.MeanAbsoluteError(
                    "energy_U0", model_output="energy_U0")
            ],
        )
        args.split = "val"
        evaluate(
            args,
            model,
            qm9_train_loader,
            qm9_val_loader,
            qm9_test_loader,
            "cpu",
            metrics=[
                schnetpack.train.metrics.MeanAbsoluteError(
                    "energy_U0", model_output="energy_U0")
            ],
        )
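Note: these snippets are excerpts from larger files, so their module-level imports are omitted. A minimal sketch of what they appear to assume is shown below; the helper functions (get_representation, get_model, get_loaders, evaluate, ...) come from SchNetPack's script and test utilities, whose exact import path may differ between versions, so only the imports that are certain from the code above are written out.

import os
import logging
from argparse import Namespace

import torch
import schnetpack
import schnetpack as spk  # both the full name and the spk alias appear in the snippets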
Example #2
def main(args):

    # setup
    #    train_args = setup_run(args)
    logging.info("CUDA is used: " + str(args.cuda))
    if args.cuda:
        logging.info("CUDA is available: " + str(torch.cuda.is_available()))

    device = torch.device("cuda" if args.cuda else "cpu")

    # get dataset
    dataset = get_dataset(args)

    # get dataloaders
    split_path = os.path.join(args.modelpath, "split.npz")
    train_loader, val_loader, test_loader = get_loaders(args,
                                                        dataset=dataset,
                                                        split_path=split_path,
                                                        logging=logging)

    # define metrics
    metrics = get_metrics(args)

    # train or evaluate
    if args.mode == "train":

        # get statistics
        atomref = dataset.get_atomref(args.property)
        divide_by_atoms = settings.divide_by_atoms[args.property]
        mean, stddev = get_statistics(
            args=args,
            split_path=split_path,
            train_loader=train_loader,
            atomref=atomref,
            divide_by_atoms=divide_by_atoms,
            logging=logging,
        )
        aggregation_mode = settings.pooling_mode[args.property]

        # build model
        model = get_model(args,
                          train_loader,
                          mean,
                          stddev,
                          atomref,
                          aggregation_mode,
                          logging=logging)

        # build trainer
        logging.info("training...")
        trainer = get_trainer(args, model, train_loader, val_loader, metrics)

        # run training
        trainer.train(device, n_epochs=args.n_epochs)
        logging.info("...training done!")

    elif args.mode == "eval":

        # remove old evaluation files
        evaluation_fp = os.path.join(args.modelpath, "evaluation.txt")
        if os.path.exists(evaluation_fp):
            if args.overwrite:
                os.remove(evaluation_fp)
            else:
                raise ScriptError(
                    "The evaluation file already exists at {}! Add the overwrite"
                    " flag to remove it.".format(evaluation_fp))

        # load model
        logging.info("loading trained model...")
        model = torch.load(os.path.join(args.modelpath, "best_model"))

        # run evaluation
        logging.info("evaluating...")
        if args.dataset != "md17":
            with torch.no_grad():
                evaluate(
                    args,
                    model,
                    train_loader,
                    val_loader,
                    test_loader,
                    device,
                    metrics=metrics,
                )
        else:
            evaluate(
                args,
                model,
                train_loader,
                val_loader,
                test_loader,
                device,
                metrics=metrics,
            )
        logging.info("... evaluation done!")

    else:
        raise ScriptError("Unknown mode: {}".format(args.mode))
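For orientation, here is a hedged sketch (not from the source) of the argument fields that this main() reads directly when run in evaluation mode; the helpers get_dataset, get_loaders and get_metrics will read additional fields (e.g. batch size, property) that are omitted here.

# hypothetical argument set for the evaluation branch of main() above
eval_args = Namespace(
    mode="eval",           # selects the evaluation branch
    cuda=False,            # evaluate on CPU
    modelpath="training",  # directory containing split.npz and best_model
    dataset="qm9",         # anything other than "md17" is evaluated under torch.no_grad()
    overwrite=True,        # replace an existing evaluation.txt instead of raising
)
main(eval_args)            # assumes the remaining fields required by the helpers are added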
Example #3
def main(args):

    # setup
    train_args = setup_run(args)

    device = torch.device("cuda" if args.cuda else "cpu")

    # get dataset
    environment_provider = get_environment_provider(train_args, device=device)
    dataset = get_dataset(train_args,
                          environment_provider=environment_provider)

    # get dataloaders
    split_path = os.path.join(args.modelpath, "split.npz")
    train_loader, val_loader, test_loader = get_loaders(args,
                                                        dataset=dataset,
                                                        split_path=split_path,
                                                        logging=logging)

    # define metrics
    metrics = get_metrics(train_args)

    # train or evaluate
    if args.mode == "train":

        # get statistics
        atomref = dataset.get_atomref(args.property)
        mean, stddev = get_statistics(
            args=args,
            split_path=split_path,
            train_loader=train_loader,
            atomref=atomref,
            divide_by_atoms=get_divide_by_atoms(args),
            logging=logging,
        )

        # build model
        model = get_model(args,
                          train_loader,
                          mean,
                          stddev,
                          atomref,
                          logging=logging)

        # build trainer
        logging.info("training...")
        trainer = get_trainer(args, model, train_loader, val_loader, metrics)

        # run training
        trainer.train(device, n_epochs=args.n_epochs)
        logging.info("...training done!")

    elif args.mode == "eval":

        # remove old evaluation files
        evaluation_fp = os.path.join(args.modelpath, "evaluation.txt")
        if os.path.exists(evaluation_fp):
            if args.overwrite:
                os.remove(evaluation_fp)
            else:
                raise ScriptError(
                    "The evaluation file already exists at {}! Add the overwrite"
                    " flag to remove it.".format(evaluation_fp))

        # load model
        logging.info("loading trained model...")
        model = torch.load(os.path.join(args.modelpath, "best_model"))

        # run evaluation
        logging.info("evaluating...")
        if spk.utils.get_derivative(train_args) is None:
            with torch.no_grad():
                evaluate(
                    args,
                    model,
                    train_loader,
                    val_loader,
                    test_loader,
                    device,
                    metrics=metrics,
                )
        else:
            evaluate(
                args,
                model,
                train_loader,
                val_loader,
                test_loader,
                device,
                metrics=metrics,
            )
        logging.info("... evaluation done!")

    else:
        raise ScriptError("Unknown mode: {}".format(args.mode))
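The only difference between the two evaluate() branches above is whether autograd stays enabled: models that predict derivatives of their output (e.g. forces, detected via spk.utils.get_derivative) cannot be evaluated inside torch.no_grad(). A compact sketch of the same pattern, assuming Python 3.7+ for contextlib.nullcontext:

import contextlib
import torch

def eval_context(needs_derivative):
    # keep autograd enabled only when output derivatives are required
    return contextlib.nullcontext() if needs_derivative else torch.no_grad()

# usage equivalent to the if/else above:
# with eval_context(spk.utils.get_derivative(train_args) is not None):
#     evaluate(args, model, train_loader, val_loader, test_loader, device, metrics=metrics)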
Example #4
    def test_eval(self, qm9_train_loader, qm9_val_loader, qm9_test_loader,
                  modeldir):
        args = Namespace(
            mode="train",
            model="schnet",
            cutoff_function="hard",
            features=100,
            n_filters=5,
            interactions=2,
            cutoff=4.0,
            num_gaussians=30,
            modelpath=modeldir,
            split=["test"],
            property="energy_U0",
            dataset="qm9",
            parallel=False,
        )
        # no precomputed statistics: mean, stddev and atomref are all left unset
        mean = {args.property: None}
        model = get_model(args,
                          train_loader=qm9_train_loader,
                          mean=mean,
                          stddev=mean,
                          atomref=mean)

        os.makedirs(modeldir, exist_ok=True)
        evaluate(
            args,
            model,
            qm9_train_loader,
            qm9_val_loader,
            qm9_test_loader,
            "cpu",
            metrics=[
                schnetpack.train.metrics.MeanAbsoluteError(
                    "energy_U0", model_output="energy_U0")
            ],
        )
        assert os.path.exists(os.path.join(modeldir, "evaluation.txt"))

        # repeat the evaluation for the train and validation splits
        args.split = ["train"]
        evaluate(
            args,
            model,
            qm9_train_loader,
            qm9_val_loader,
            qm9_test_loader,
            "cpu",
            metrics=[
                schnetpack.train.metrics.MeanAbsoluteError(
                    "energy_U0", model_output="energy_U0")
            ],
        )
        args.split = ["validation"]
        evaluate(
            args,
            model,
            qm9_train_loader,
            qm9_val_loader,
            qm9_test_loader,
            "cpu",
            metrics=[
                schnetpack.train.metrics.MeanAbsoluteError(
                    "energy_U0", model_output="energy_U0")
            ],
        )
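Since the metrics argument is a plain list, several metrics can be reported per split in a single call. A hedged sketch, assuming RootMeanSquaredError is available next to MeanAbsoluteError in schnetpack.train.metrics (as in SchNetPack 0.x):

metrics = [
    schnetpack.train.metrics.MeanAbsoluteError(
        "energy_U0", model_output="energy_U0"),
    schnetpack.train.metrics.RootMeanSquaredError(
        "energy_U0", model_output="energy_U0"),
]
evaluate(args, model, qm9_train_loader, qm9_val_loader, qm9_test_loader,
         "cpu", metrics=metrics)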