def test_setup_overwrite(self, modeldir):
    """Check that setup_run honors the overwrite flag and round-trips args.

    A train-mode setup with ``overwrite=True`` must wipe the model
    directory; a subsequent eval-mode setup must reload the same
    training arguments that the train run produced.
    """
    # Plant a dummy subdirectory that the overwrite is expected to remove.
    dummy_dir = os.path.join(modeldir, "testing")
    os.makedirs(dummy_dir)

    # Train-mode setup with overwrite=True clears the model directory.
    train_namespace = Namespace(
        mode="train", modelpath=modeldir, overwrite=True, seed=20
    )
    train_args = setup_run(train_namespace)
    assert not os.path.exists(dummy_dir)

    # Eval-mode setup reloads the stored training arguments unchanged.
    eval_namespace = Namespace(mode="eval", modelpath=modeldir, seed=20)
    assert setup_run(eval_namespace) == train_args
def main(args):
    """Train a model according to the parsed command-line *args*.

    Only ``mode == "train"`` is supported by this variant; any other mode
    raises ``ScriptError``.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments (modelpath, cuda, mode, property,
        n_epochs, ...).
    """
    # setup
    train_args = setup_run(args)
    device = torch.device("cuda" if args.cuda else "cpu")

    # get dataset
    environment_provider = get_environment_provider(train_args, device=device)
    dataset = get_dataset(train_args, environment_provider=environment_provider)

    # get dataloaders
    split_path = os.path.join(args.modelpath, "split.npz")
    train_loader, val_loader, test_loader = get_loaders(
        args, dataset=dataset, split_path=split_path, logging=logging
    )

    # define metrics
    metrics = get_metrics(train_args)

    # train or evaluate
    if args.mode == "train":
        # get statistics used to normalize the target property
        atomref = dataset.get_atomref(args.property)
        mean, stddev = get_statistics(
            args=args,
            split_path=split_path,
            train_loader=train_loader,
            atomref=atomref,
            divide_by_atoms=get_divide_by_atoms(args),
            logging=logging,
        )

        # build model
        model = get_model(args, train_loader, mean, stddev, atomref, logging=logging)

        # build trainer
        logging.info("training...")
        trainer = get_trainer(args, model, train_loader, val_loader, metrics)

        # run training
        trainer.train(device, n_epochs=args.n_epochs)
        logging.info("...training done!")
    else:
        # BUG FIX: the original did `raise ("...")`, which raises a bare
        # string and therefore fails with "TypeError: exceptions must derive
        # from BaseException" instead of showing the intended message.
        # Raise ScriptError, consistent with the rest of this file.
        raise ScriptError("Use the original SchnetPack script instead.")
def main(args):
    """Train a model or evaluate a previously trained one.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments (modelpath, cuda, mode, property,
        n_epochs, overwrite, ...).

    Raises
    ------
    ScriptError
        If an evaluation file already exists without the overwrite flag,
        or if ``args.mode`` is neither "train" nor "eval".
    """
    from contextlib import nullcontext

    # setup
    train_args = setup_run(args)
    device = torch.device("cuda" if args.cuda else "cpu")

    # get dataset
    environment_provider = get_environment_provider(train_args, device=device)
    dataset = get_dataset(train_args, environment_provider=environment_provider)

    # get dataloaders
    split_path = os.path.join(args.modelpath, "split.npz")
    train_loader, val_loader, test_loader = get_loaders(
        args, dataset=dataset, split_path=split_path, logging=logging
    )

    # define metrics
    metrics = get_metrics(train_args)

    # train or evaluate
    if args.mode == "train":
        # get statistics used to normalize the target property
        atomref = dataset.get_atomref(args.property)
        mean, stddev = get_statistics(
            args=args,
            split_path=split_path,
            train_loader=train_loader,
            atomref=atomref,
            divide_by_atoms=get_divide_by_atoms(args),
            logging=logging,
        )

        # build model
        model = get_model(args, train_loader, mean, stddev, atomref, logging=logging)

        # build trainer
        logging.info("training...")
        trainer = get_trainer(args, model, train_loader, val_loader, metrics)

        # run training
        trainer.train(device, n_epochs=args.n_epochs)
        logging.info("...training done!")
    elif args.mode == "eval":
        # remove old evaluation files
        evaluation_fp = os.path.join(args.modelpath, "evaluation.txt")
        if os.path.exists(evaluation_fp):
            if args.overwrite:
                os.remove(evaluation_fp)
            else:
                raise ScriptError(
                    "The evaluation file does already exist at {}! Add overwrite flag"
                    " to remove.".format(evaluation_fp))

        # load model
        logging.info("loading trained model...")
        model = torch.load(os.path.join(args.modelpath, "best_model"))

        # run evaluation
        logging.info("evaluating...")
        # Disable autograd only when no derivative output is configured;
        # presumably derivative properties (e.g. forces) need grad enabled —
        # this mirrors the original branch structure.  Using nullcontext
        # removes the duplicated evaluate(...) call the original had in
        # both arms of the conditional.
        grad_context = (
            torch.no_grad()
            if spk.utils.get_derivative(train_args) is None
            else nullcontext()
        )
        with grad_context:
            evaluate(
                args,
                model,
                train_loader,
                val_loader,
                test_loader,
                device,
                metrics=metrics,
            )
        logging.info("... evaluation done!")
    else:
        raise ScriptError("Unknown mode: {}".format(args.mode))
# BUG FIX: this line previously began with an orphaned, syntax-breaking
# fragment (a mid-expression duplicate of main()'s eval branch, starting at
# "args, model, train_loader, ...") — a leftover from a bad merge/paste.
# That code exists intact inside main(); only the entry guard is kept.
if __name__ == "__main__":
    parser = build_parser()
    args = parser.parse_args()
    # NOTE(review): setup_run is also called inside main(); the extra call
    # here looks redundant but is kept to preserve existing behavior —
    # confirm whether it is required (e.g. for argument normalization).
    args = setup_run(args)
    print("*", args)
    main(args)