Esempio n. 1
0
def build_fn(atom_features: int = 128,
             message_steps: int = 8,
             atomwise: bool = False,
             output_layers: int = 3,
             reduce_fn: str = 'sum',
             mean: Optional[float] = None,
             std: Optional[float] = None):
    """Assemble a SchNet model with a single output head for property ``'ip'``.

    Args:
        atom_features: used for both ``n_atom_basis`` and ``n_filters`` of
            SchNet, and as ``n_in`` of the output head.
        message_steps: number of SchNet interaction blocks.
        atomwise: if True use an ``Atomwise`` head, otherwise ``Moleculewise``.
        output_layers: ``n_layers`` of the output head.
        reduce_fn: ``aggregation_mode`` of the output head.
        mean: forwarded to the output head as ``mean``.
        std: forwarded to the output head as ``stddev``.

    Returns:
        ``spk.AtomisticModel`` wrapping the representation and the head.
    """
    representation = spk.representation.SchNet(
        n_atom_basis=atom_features,
        n_filters=atom_features,
        n_gaussians=20,
        n_interactions=message_steps,
        cutoff=4.,
        cutoff_network=spk.nn.cutoff.CosineCutoff,
    )

    # Both head classes take identical keyword arguments; only the class differs.
    head_cls = Atomwise if atomwise else Moleculewise
    head = head_cls(
        n_in=atom_features,
        n_layers=output_layers,
        aggregation_mode=reduce_fn,
        mean=mean,
        stddev=std,
        property='ip',
    )
    return spk.AtomisticModel(representation=representation, output_modules=head)
Esempio n. 2
0
def test_charge_correction(schnet_batch, n_atom_basis):
    """
    Check that DipoleMoment's charge correction reproduces the requested totals.

    Random integer totals are stored in the batch under key ``"q"``; the
    model's ``"q"`` output, summed along dim 1 (presumably the atom axis),
    must match them.
    """
    representation = spk.SchNet(n_atom_basis)
    dipole_head = spk.atomistic.DipoleMoment(
        n_atom_basis, charge_correction="q", contributions="q"
    )
    model = spk.AtomisticModel(representation, dipole_head)

    n_molecules = schnet_batch["_positions"].shape[0]
    target_charges = torch.randint(0, 10, (n_molecules, 1))
    schnet_batch.update(q=target_charges)

    predicted = model(schnet_batch)["q"]
    predicted_totals = predicted.sum(1)

    assert torch.allclose(target_charges.float(), predicted_totals, atol=1e-6)
def get_model(args, train_loader, mean, stddev, atomref, logging=logging):
    """Build a SchNetPack model from script arguments.

    Args:
        args: script arguments; ``args.parallel`` selects DataParallel wrapping.
        train_loader: loader handed to ``get_representation``.
        mean: forwarded to ``get_output_module``.
        stddev: forwarded to ``get_output_module``.
        atomref: forwarded to ``get_output_module``.
        logging: object with an ``info`` method, or a falsy value to stay silent.

    Returns:
        The ``spk.AtomisticModel``, wrapped in ``nn.DataParallel`` if requested.
    """
    if logging:
        logging.info("building model...")

    representation = get_representation(args, train_loader)
    output_modules = [
        get_output_module(
            args,
            representation=representation,
            mean=mean,
            stddev=stddev,
            atomref=atomref,
        )
    ]
    model = spk.AtomisticModel(representation, output_modules)
    model = nn.DataParallel(model) if args.parallel else model

    if logging:
        n_params = spk.utils.count_params(model)
        logging.info("The model you built has: %d parameters" % n_params)
    return model
Esempio n. 4
0
def get_model(args, train_loader, mean, stddev, atomref, logging=None):
    """Build a model for training, optionally with a dropout output head.

    The standard ``get_output_module`` head is used when ``args.dropout == 0``
    and ``args.n_layers == 2``. Otherwise ``get_outmod_with_dropout`` is used.

    Args:
        args: script arguments (``mode``, ``dropout``, ``n_layers``, ``parallel``).
        train_loader: loader handed to ``get_representation``.
        mean: forwarded to the output-module builder.
        stddev: forwarded to the output-module builder.
        atomref: forwarded to the output-module builder.
        logging: object with an ``info`` method, or None.

    Returns:
        The ``spk.AtomisticModel``, wrapped in ``nn.DataParallel`` if requested.

    Raises:
        spk.utils.ScriptError: if ``args.mode`` is not ``"train"``.
    """
    if args.mode != "train":
        raise spk.utils.ScriptError("Invalid mode selected: {}".format(
            args.mode))

    if logging:
        logging.info("building model...")

    use_standard_head = args.dropout == 0 and args.n_layers == 2
    if use_standard_head:
        from schnetpack.utils import get_representation, get_output_module
    else:
        from schnetpack.utils import get_representation

    representation = get_representation(args, train_loader)
    build_head = (get_output_module if use_standard_head
                  else get_outmod_with_dropout)
    head = build_head(
        args,
        representation=representation,
        mean=mean,
        stddev=stddev,
        atomref=atomref,
    )
    model = spk.AtomisticModel(representation, [head])

    if args.parallel:
        model = nn.DataParallel(model)
    if logging:
        param_total = spk.utils.count_params(model)
        logging.info("The model you built has: %d parameters" % param_total)
    return model
Esempio n. 5
0
def get_model(args, train_loader, mean, stddev, atomref, logging=None):
    """
    Build a model from the selected parameters (training mode only).

    Args:
        args (argsparse.Namespace): script arguments (``mode``, ``parallel``, ...)
        train_loader (spk.AtomsLoader): loader for training data
        mean (torch.Tensor): mean of training data
        stddev (torch.Tensor): stddev of training data
        atomref (dict): atomic references
        logging: object with an ``info`` method, or None

    Returns:
        spk.AtomisticModel: model for training, wrapped in ``nn.DataParallel``
        when ``args.parallel`` is set

    Raises:
        spk.utils.ScriptError: if ``args.mode`` is not ``"train"``
    """
    if args.mode != "train":
        raise spk.utils.ScriptError("Invalid mode selected: {}".format(args.mode))

    if logging:
        logging.info("building model...")

    representation = get_representation(args, train_loader)
    head = get_output_module(
        args,
        representation=representation,
        mean=mean,
        stddev=stddev,
        atomref=atomref,
    )
    model = spk.AtomisticModel(representation, [head])

    if args.parallel:
        model = nn.DataParallel(model)

    if logging:
        parameter_count = spk.utils.count_params(model)
        logging.info(
            "The model you built has: %d parameters" % parameter_count
        )
    return model
Esempio n. 6
0
def model(args, omdData, atomrefs, means, stddevs):
    """Build a SchNet model whose output head comes from ``get_output_module``.

    Args:
        args: script arguments; ``args.features`` sets the SchNet width and
            ``args.parallel`` selects DataParallel wrapping.
        omdData: accepted for interface compatibility; not used here.
        atomrefs: forwarded to ``get_output_module`` as ``atomref``.
        means: forwarded to ``get_output_module`` as ``mean``.
        stddevs: forwarded to ``get_output_module`` as ``stddev``.

    Returns:
        The ``spk.AtomisticModel``, wrapped in ``nn.DataParallel`` if requested.
    """
    width = args.features
    representation = spk.representation.SchNet(
        n_atom_basis=width,
        n_filters=width,
        n_gaussians=50,
        n_interactions=6,
        cutoff=5.0,
        cutoff_network=spk.nn.cutoff.CosineCutoff,
    )
    head = get_output_module(
        args,
        representation=representation,
        mean=means,
        stddev=stddevs,
        atomref=atomrefs,
    )
    net = spk.AtomisticModel(representation=representation,
                             output_modules=head)
    return nn.DataParallel(net) if args.parallel else net
Esempio n. 7
0
# Per-property mean/stddev of the training targets. per_atom=True together
# with the atomic references presumably yields per-atom statistics — see
# schnetpack's AtomsLoader.get_statistics to confirm the exact definition.
means, stddevs = train_loader.get_statistics(
    properties, per_atom=True, atomrefs=atomrefs
)

# model build
logging.info("build model")
# SchNet representation with 6 interaction blocks; all other settings default.
representation = spk.SchNet(n_interactions=6)
# A single atom-wise head for QM9.U0, using the statistics and atomic
# references computed above.
output_modules = [
    spk.Atomwise(
        property=QM9.U0,
        mean=means[QM9.U0],
        stddev=stddevs[QM9.U0],
        atomref=atomrefs[QM9.U0],
    )
]
model = spk.AtomisticModel(representation, output_modules)

# build optimizer
optimizer = Adam(model.parameters(), lr=1e-4)

# hooks
logging.info("build trainer")
# One MAE metric per property; the two arguments are presumably the target
# key and the model-output key, which coincide here.
metrics = [MeanAbsoluteError(p, p) for p in properties]
# CSVHook logs the metrics under model_dir; ReduceLROnPlateauHook presumably
# lowers the optimizer's learning rate when validation loss stops improving.
hooks = [CSVHook(log_path=model_dir, metrics=metrics), ReduceLROnPlateauHook(optimizer)]

# trainer
loss = mse_loss(properties)
trainer = Trainer(
    model_dir,
    model=model,
    hooks=hooks,
Esempio n. 8
0
# Optional shifted-sigmoid activation. In the code shown here it is only
# referenced by the commented-out output_activation option below.
act = ShiftedSigmoid()

# Create the model and optimizer. NOTE: this requires a modified fork of
# schnetpack, not the upstream package. The fork adds optional noise on the
# atom positions (use_noise / noise_mean / noise_std) and a feature vector
# that shrinks across interaction layers (presumably expressed by passing
# n_filters as a list, one entry per interaction).
model = schnetpack.representation.SchNet(use_noise=False, noise_mean=0.0, noise_std=0.1, chargeEmbedding = True,
                                         ownFeatures = False, nFeatures = 8, finalFeature = None,
                                         max_z=200, n_atom_basis=20, n_filters=[32, 24, 16, 8, 4], n_gaussians=25,
                                         normalize_filter=False, coupled_interactions=False, trainable_gaussians=False,
                                         n_interactions=5, cutoff=2.5,
                                         cutoff_network=schnetpack.nn.cutoff.CosineCutoff)
# The fork also allows an activation function on the output of the output
# network (see the commented-out output_activation argument below).
# n_in=20 matches n_atom_basis above; averaging is done after the output net.
d = schnetpack.atomistic.Atomwise(n_in=20, aggregation_mode='avg',
                                  n_layers=4, mode='postaggregate') # , activation=F.relu)#, output_activation=act.forward)
model = schnetpack.AtomisticModel(model, d)
optimizer = Adam(model.parameters())


# This function counts the number of trainable parameters of the network
# Taken from https://stackoverflow.com/questions/48393608/pytorch-network-parameter-calculation (Wasi Ahmad https://stackoverflow.com/users/5352399/wasi-ahmad)
def count_parameters(model):
    paramCount = 0
    for parameterName, parameterValue in model.named_parameters():
        if parameterValue.requires_grad:
            n = np.prod(parameterValue.size())
            if parameterValue.dim() <= 1:
            print(parameterName, ':', n)
            else:
                layerSize = list(parameterValue.size())
                if (len(layerSize) > 0):
Esempio n. 9
0
def _reset_storage_dir(storage_dir):
    """Create ``storage_dir`` and remove artefacts left by a previous run."""
    os.makedirs(storage_dir, exist_ok=True)

    checkpoints = os.path.join(storage_dir, "checkpoints")
    if os.path.exists(checkpoints):
        shutil.rmtree(checkpoints)

    for name in ("log.csv", "best_model"):
        stale = os.path.join(storage_dir, name)
        if os.path.exists(stale):
            os.remove(stale)


def run(split_path, dataset_path, n_train=None, n_val=None, n_epochs=1000):
    """Train a SchNet energy/force model on an MD17 dataset.

    Outputs go to the ``Info`` directory. Previous checkpoints, ``log.csv``
    and ``best_model`` there are deleted first. When training finishes,
    ``Info/best_model`` is renamed to ``Info/model_new``.

    Args:
        split_path: ``.npz`` split file handed to ``spk.train_test_split``.
        dataset_path: path of the MD17 database.
        n_train: forwarded to ``spk.train_test_split`` as ``num_train``
            (needed when ``split_path`` does not exist yet).
        n_val: forwarded to ``spk.train_test_split`` as ``num_val``.
        n_epochs: maximum number of training epochs.
    """
    storage_dir = "Info"
    _reset_storage_dir(storage_dir)

    data = MD17(dataset_path)

    # n_train / n_val used to be silently ignored; forwarding them keeps the
    # existing behaviour for a pre-existing split file (None is the default).
    train, val, test = spk.train_test_split(
        data=data,
        num_train=n_train,
        num_val=n_val,
        split_file=split_path,
    )

    train_loader = spk.AtomsLoader(train, batch_size=100, shuffle=True)
    val_loader = spk.AtomsLoader(val, batch_size=100)

    means, stddevs = train_loader.get_statistics(
        MD17.energy, divide_by_atoms=True
    )

    # Marker file written by the original script; kept for compatibility.
    with open("out.txt", "w+") as file:
        file.write("IN MD17_train")

    print('Mean atomization energy / atom:      {:12.4f} [kcal/mol]'.format(means[MD17.energy][0]))
    print('Std. dev. atomization energy / atom: {:12.4f} [kcal/mol]'.format(stddevs[MD17.energy][0]))

    n_features = 64
    schnet = spk.representation.SchNet(
        n_atom_basis=n_features,
        n_filters=n_features,
        n_gaussians=25,
        n_interactions=6,
        cutoff=5.,
        cutoff_network=spk.nn.cutoff.CosineCutoff,
    )

    energy_model = spk.atomistic.Atomwise(
        n_in=n_features,
        property=MD17.energy,
        mean=means[MD17.energy],
        stddev=stddevs[MD17.energy],
        derivative=MD17.forces,
        negative_dr=True,
    )

    model = spk.AtomisticModel(representation=schnet, output_modules=energy_model)

    # Weight of the energy term in the loss; forces get 1 - rho_tradeoff.
    rho_tradeoff = 0.1
    optimizer = Adam(model.parameters(), lr=1e-3)

    def loss(batch, result):
        """Weighted sum of the energy MSE and the force MSE."""
        diff_energy = batch[MD17.energy] - result[MD17.energy]
        err_sq_energy = torch.mean(diff_energy ** 2)

        diff_forces = batch[MD17.forces] - result[MD17.forces]
        err_sq_forces = torch.mean(diff_forces ** 2)

        return rho_tradeoff * err_sq_energy + (1 - rho_tradeoff) * err_sq_forces

    metrics = [
        spk.metrics.MeanAbsoluteError(MD17.energy),
        spk.metrics.MeanAbsoluteError(MD17.forces),
    ]

    hooks = [
        trn.CSVHook(log_path=storage_dir, metrics=metrics),
        trn.ReduceLROnPlateauHook(
            optimizer,
            patience=150, factor=0.8, min_lr=1e-6,
            stop_after_min=True,
        ),
    ]

    trainer = trn.Trainer(
        model_path=storage_dir,
        model=model,
        hooks=hooks,
        loss_fn=loss,
        optimizer=optimizer,
        train_loader=train_loader,
        validation_loader=val_loader,
    )

    # Use a GPU when available, otherwise fall back to the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    trainer.train(device=device, n_epochs=n_epochs)

    # os.replace also overwrites an existing model_new on Windows, where
    # os.rename would raise.
    os.replace(os.path.join(storage_dir, "best_model"),
               os.path.join(storage_dir, "model_new"))
Esempio n. 10
0
def schnet_train_default(self, train_indices, model_path, old_model_path,
                         schnet_args):
    """Train a SchNet energy/force model on ``self.dataset``.

    Args:
        train_indices: either an integer (number of training points, split
            drawn by ``spk.train_test_split``) or a sequence of dataset
            indices to train on. In the latter case ``n_val`` validation
            points are drawn at random from the remaining indices, and the
            rest form the test set.
        model_path: directory created at the end. It receives the best model
            (as ``model``) and the split file (as ``split.npz``).
        old_model_path: accepted for interface compatibility; not used here.
        schnet_args: dict of optional hyper-parameters (``n_val``,
            ``batch_size``, ``n_features``, ``n_gaussians``, ``n_interactions``,
            ``cutoff``, ``learning_rate``, ``rho_tradeoff``, ``patience``,
            ``n_epochs``). Missing keys fall back to the defaults below.

    Raises:
        FileExistsError: if ``model_path`` already exists.
    """
    import schnetpack as spk
    import schnetpack.train as trn
    import torch
    from torch.optim import Adam

    n_val = schnet_args.get("n_val", 100)

    #  LOADING train, val, test
    # numpy integers are sizes too: they have no len(), so the old
    # ``type(x) == int`` check sent them down the index-sequence path and crashed.
    if isinstance(train_indices, (int, np.integer)):
        n_train = int(train_indices)
        explicit_indices = None
    else:
        n_train = len(train_indices)
        explicit_indices = train_indices

    # Preparing storage (shared by both split strategies)
    storage = os.path.join(self.temp_dir, f"schnet_{n_train}")
    os.makedirs(storage, exist_ok=True)
    split_path = os.path.join(storage, "split.npz")

    if explicit_indices is None:
        train, val, test = spk.train_test_split(data=self.dataset,
                                                num_train=n_train,
                                                num_val=n_val,
                                                split_file=split_path)
    else:
        # Dataset indices not used for training. Because they come from an
        # arange, deleting positions here is the same as deleting values.
        remaining = np.delete(np.arange(len(self.dataset)), explicit_indices)

        # n_val random remaining points for validation; the rest is test.
        val_pos = np.random.choice(np.arange(len(remaining)),
                                   n_val,
                                   replace=False)
        val_ind = remaining[val_pos]
        test_ind = np.delete(remaining, val_pos)

        np.savez_compressed(split_path,
                            train_idx=explicit_indices,
                            val_idx=val_ind,
                            test_idx=test_ind)

        train, val, test = spk.train_test_split(data=self.dataset,
                                                split_file=split_path)

    print_ongoing_process(f"Preparing SchNet training, {len(train)} points",
                          True)

    batch_size = schnet_args.get("batch_size", 10)
    n_features = schnet_args.get("n_features", 64)
    n_gaussians = schnet_args.get("n_gaussians", 25)
    n_interactions = schnet_args.get("n_interactions", 6)
    cutoff = schnet_args.get("cutoff", 5.0)
    learning_rate = schnet_args.get("learning_rate", 1e-3)
    rho_tradeoff = schnet_args.get("rho_tradeoff", 0.1)
    patience = schnet_args.get("patience", 5)
    n_epochs = schnet_args.get("n_epochs", 100)

    #  PRINTING INFO
    params_info = {
        "batch_size": batch_size,
        "n_features": n_features,
        "n_gaussians": n_gaussians,
        "n_interactions": n_interactions,
        "cutoff": cutoff,
        "learning_rate": learning_rate,
        "rho_tradeoff": rho_tradeoff,
        "patience": patience,
        "n_epochs": n_epochs,
        "n_val": n_val,
    }
    print_table("Parameters", None, None, params_info, width=20)
    print()

    train_loader = spk.AtomsLoader(train, shuffle=True, batch_size=batch_size)
    val_loader = spk.AtomsLoader(val, batch_size=batch_size)

    #  STATISTICS + PRINTS
    means, stddevs = train_loader.get_statistics("energy",
                                                 divide_by_atoms=True)
    print_info(
        "Mean atomization energy / atom:      {:12.4f} [kcal/mol]".format(
            means["energy"][0]))
    print_info(
        "Std. dev. atomization energy / atom: {:12.4f} [kcal/mol]".format(
            stddevs["energy"][0]))

    #  LOADING MODEL
    print_ongoing_process("Loading representation and model")
    schnet = spk.representation.SchNet(
        n_atom_basis=n_features,
        n_filters=n_features,
        n_gaussians=n_gaussians,
        n_interactions=n_interactions,
        cutoff=cutoff,
        cutoff_network=spk.nn.cutoff.CosineCutoff,
    )

    energy_model = spk.atomistic.Atomwise(
        n_in=n_features,
        property="energy",
        mean=means["energy"],
        stddev=stddevs["energy"],
        derivative="forces",
        negative_dr=True,
    )

    model = spk.AtomisticModel(representation=schnet,
                               output_modules=energy_model)
    print_ongoing_process("Loading representation and model", True)

    #  OPTIMIZER AND LOSS
    print_ongoing_process("Defining loss function and optimizer")
    optimizer = Adam(model.parameters(), lr=learning_rate)

    def loss(batch, result):
        """rho_tradeoff * energy MSE + (1 - rho_tradeoff) * force MSE."""
        diff_energy = batch["energy"] - result["energy"]
        err_sq_energy = torch.mean(diff_energy**2)

        diff_forces = batch["forces"] - result["forces"]
        err_sq_forces = torch.mean(diff_forces**2)

        return (rho_tradeoff * err_sq_energy
                + (1 - rho_tradeoff) * err_sq_forces)

    print_ongoing_process("Defining loss function and optimizer", True)

    # METRICS AND HOOKS
    print_ongoing_process("Setting up metrics and hooks")
    metrics = [
        spk.metrics.MeanAbsoluteError("energy"),
        spk.metrics.MeanAbsoluteError("forces"),
    ]

    hooks = [
        trn.CSVHook(log_path=storage, metrics=metrics),
        # Use the configured patience; it was printed but previously ignored
        # in favour of a hard-coded 5.
        trn.ReduceLROnPlateauHook(optimizer,
                                  patience=patience,
                                  factor=0.8,
                                  min_lr=1e-6,
                                  stop_after_min=True),
    ]
    print_ongoing_process("Setting up metrics and hooks", True)

    print_ongoing_process("Setting up trainer")

    trainer = trn.Trainer(
        model_path=storage,
        model=model,
        hooks=hooks,
        loss_fn=loss,
        optimizer=optimizer,
        train_loader=train_loader,
        validation_loader=val_loader,
    )

    print_ongoing_process("Setting up trainer", True)

    if torch.cuda.is_available():
        device = "cuda"
        print_info("Cuda cores found, training on GPU")
    else:
        device = "cpu"
        print_info("No cuda cores found, training on CPU")

    # Start and end messages must match (the start message used to say "ecpochs").
    training_msg = f"Training {n_epochs} epochs, out in {storage}"
    print_ongoing_process(training_msg)
    trainer.train(device=device, n_epochs=n_epochs)
    print_ongoing_process(training_msg, True)

    os.makedirs(model_path)

    # shutil.move also works when temp_dir and model_path are on different
    # filesystems, where os.rename raises OSError.
    shutil.move(os.path.join(storage, "best_model"),
                os.path.join(model_path, "model"))
    shutil.copy(split_path, os.path.join(model_path, "split.npz"))
Esempio n. 11
0
def create_model(args, atomrefs, means, stddevs, properties, avg_n_atoms):
    """Build an equivariant network plus one ``Atomwise`` head per property.

    Args:
        args: hyper-parameters. ``rad_*`` configure the radial kernel,
            ``res`` picks ``ResNetwork`` over ``Network``, ``mlp_out`` picks
            ``OutputMLPNetwork`` over ``OutputScalarNetwork`` and
            ``outnet_*`` configure that output network.
        atomrefs: per-property atomic references.
        means: per-property means forwarded to each head.
        stddevs: per-property standard deviations forwarded to each head.
        properties: iterable of property keys; one head is built per key.
        avg_n_atoms: forwarded to the network and the output network.

    Returns:
        ``spk.AtomisticModel`` combining the network and the heads.
    """
    shifted_softplus = rescaled_act.ShiftedSoftplus(beta=args.beta)
    kernel_conv = create_kernel_conv(
        cutoff=args.rad_maxr,
        n_bases=args.rad_nb,
        n_neurons=args.rad_h,
        n_layers=args.rad_L,
        act=shifted_softplus,
        radial_model=args.radial_model
    )

    softplus = rescaled_act.Softplus(beta=args.beta)

    # Both network classes take identical keyword arguments.
    network_cls = ResNetwork if args.res else Network
    net = network_cls(
        kernel_conv=kernel_conv,
        embed=args.embed,
        l0=args.l0,
        l1=args.l1,
        l2=args.l2,
        l3=args.l3,
        L=args.L,
        scalar_act=softplus,
        gate_act=rescaled_act.sigmoid,
        avg_n_atoms=avg_n_atoms
    )

    identity = torch.nn.Identity()
    last_rs = net.Rs[-1]

    if not args.mlp_out:
        outnet = OutputScalarNetwork(
            kernel_conv=kernel_conv,
            previous_Rs=last_rs,
            scalar_act=identity,
            avg_n_atoms=avg_n_atoms
        )
    else:
        outnet = OutputMLPNetwork(
            kernel_conv=kernel_conv,
            previous_Rs=last_rs,
            l0=args.outnet_l0,
            l1=args.outnet_l1,
            l2=args.outnet_l2,
            l3=args.outnet_l3,
            L=args.outnet_L,
            scalar_act=softplus,
            gate_act=rescaled_act.sigmoid,
            mlp_h=args.outnet_neurons,
            mlp_L=args.outnet_layers,
            avg_n_atoms=avg_n_atoms
        )

    # NOTE(review): the same ``outnet`` instance is handed to every head, so
    # with several properties all heads share its weights — confirm intended.
    output_modules = []
    for prop in properties:
        output_modules.append(
            spk.atomistic.Atomwise(
                property=prop,
                mean=means[prop],
                stddev=stddevs[prop],
                atomref=atomrefs[prop],
                outnet=outnet,
            )
        )
    return spk.AtomisticModel(net, output_modules)
Esempio n. 12
0
    def train(self,
              resultfolder,
              traindb='Data/dataset_10_12_train_combined.db',
              benchdb='Data/dataset_10_12_test.db',
              traindata='../../Data/combined1618/',
              benchdata='../../Data/test/',
              indexpath='../../Data/INDEX_refined_data.2016.2018',
              properties=None,
              threshold=10,
              cutoff=8,
              numVal=150,
              featureset=False,
              trainBatchsize=8,
              valBatchsize=1,
              benchBatchsize=1,
              natoms=None,
              props=False,
              ntrain=4444,
              ntest=290,
              use_noise=False,
              noise_mean=0.0,
              noise_std=0.1,
              chargeEmbedding=True,
              ownFeatures=False,
              nFeatures=8,
              finalFeature=None,
              max_z=200,
              n_atom_basis=20,
              n_filters=32,
              n_gaussians=25,
              normalize_filter=False,
              coupled_interactions=False,
              trainable_gaussians=False,
              n_interactions=5,
              distanceCutoff=2.5,
              cutoff_network=schnetpack.nn.cutoff.CosineCutoff,
              outputIn=32,
              outAggregation='avg',
              outLayer=2,
              outMode='postaggregate',
              outAct=schnetpack.nn.activations.shifted_softplus,
              outOutAct=None,
              n_acc_steps=8,
              remember=10,
              ensembleModel=False,
              n_epochs=150,
              lr=1e-3,
              weight_decay=0,
              train_loader=None,
              val_loader=None,
              splitfile=None,
              noProtons=False):
        """Build a (forked) SchNet model and train it on CUDA.

        Loaders are built with ``self.createDataloader`` unless both
        ``train_loader`` and ``val_loader`` are given. The data-related
        arguments (``traindb`` ... ``ntest``, ``splitfile``, ``noProtons``)
        are forwarded to it. SchNet, Atomwise, AdamW and Trainer receive the
        remaining arguments. The CSV log and the trainer's model output both
        go to ``resultfolder``.

        Args:
            resultfolder: output directory used by CSVHook and the Trainer.
            properties: target properties for ``createDataloader``;
                defaults to ``['KD']`` (a fresh list per call).
            n_acc_steps: gradient-accumulation steps passed to the Trainer.
            n_epochs: maximum epochs; EarlyStoppingHook may stop earlier.
        """
        # Avoid a shared mutable default list across calls.
        if properties is None:
            properties = ['KD']

        print('Device: ', torch.cuda.current_device())
        torch.cuda.empty_cache()

        result_folder = resultfolder
        # A shifted sigmoid could be used as output activation instead:
        # act = ShiftedSigmoid()
        if train_loader is None or val_loader is None:
            with open("log.txt", "a") as log_file:
                log_file.write(
                    str(datetime.datetime.now()) + ' ' + result_folder +
                    ' create loader by its own' + '\n')

            train_loader, val_loader, bench_loader = self.createDataloader(
                traindb=traindb,
                benchdb=benchdb,
                traindata=traindata,
                benchdata=benchdata,
                indexpath=indexpath,
                properties=properties,
                threshold=threshold,
                cutoff=cutoff,
                numVal=numVal,
                featureset=featureset,
                trainBatchsize=trainBatchsize,
                valBatchsize=valBatchsize,
                benchBatchsize=benchBatchsize,
                natoms=natoms,
                props=props,
                ntrain=ntrain,
                ntest=ntest,
                splitfile=splitfile,
                noProtons=noProtons)

        representation = schnetpack.representation.SchNet(
            use_noise=use_noise,
            noise_mean=noise_mean,
            noise_std=noise_std,
            chargeEmbedding=chargeEmbedding,
            ownFeatures=ownFeatures,
            nFeatures=nFeatures,
            finalFeature=finalFeature,
            max_z=max_z,
            n_atom_basis=n_atom_basis,
            n_filters=n_filters,
            n_gaussians=n_gaussians,
            normalize_filter=normalize_filter,
            coupled_interactions=coupled_interactions,
            trainable_gaussians=trainable_gaussians,
            n_interactions=n_interactions,
            cutoff=distanceCutoff,
            cutoff_network=cutoff_network)
        # NOTE(review): default outputIn=32 differs from n_atom_basis=20;
        # presumably the forked SchNet's output width differs from
        # n_atom_basis — confirm.
        head = schnetpack.atomistic.Atomwise(n_in=outputIn,
                                             aggregation_mode=outAggregation,
                                             n_layers=outLayer,
                                             mode=outMode,
                                             activation=outAct,
                                             output_activation=outOutAct)

        model = schnetpack.AtomisticModel(representation, head)
        optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

        print(model)
        print('number of trainable parameters =',
              SchnetTraining.count_parameters(model))

        # Defines Metrics, Hooks and Loss
        loss = SchnetTraining.mse_loss()
        # MSE metric for validation, comparing target 'KD' with output 'y'
        metrics = [schnetpack.metrics.MeanSquaredError('KD', model_output='y')]
        # CSVHook -> creates a log file
        # ReduceLROnPlateauHook -> learning-rate decay
        # EarlyStoppingHook -> may end training before n_epochs
        hooks = [
            schnetpack.train.CSVHook(log_path=result_folder, metrics=metrics),
            schnetpack.train.ReduceLROnPlateauHook(optimizer,
                                                   patience=7,
                                                   factor=0.5,
                                                   min_lr=1e-5),
            schnetpack.train.EarlyStoppingHook(patience=40,
                                               threshold_ratio=0.001)
        ]

        # n_acc_steps accumulates gradients over several batches for GPUs
        # that cannot hold large batches.
        trainer = schnetpack.train.Trainer(model_path=result_folder,
                                           model=model,
                                           loss_fn=loss,
                                           train_loader=train_loader,
                                           optimizer=optimizer,
                                           validation_loader=val_loader,
                                           hooks=hooks,
                                           n_acc_steps=n_acc_steps,
                                           keep_n_checkpoints=10,
                                           checkpoint_interval=5,
                                           remember=remember,
                                           ensembleModel=ensembleModel)

        # Train on CUDA for at most n_epochs; early stopping may end it sooner.
        trainer.train(device='cuda', n_epochs=n_epochs)
Esempio n. 13
0
def atomistic_model(schnet, output_modules):
    """Combine a representation and its output module(s) into an AtomisticModel.

    Args:
        schnet: representation network (passed positionally as the first argument).
        output_modules: output module or list of modules for the model.

    Returns:
        The resulting ``spk.AtomisticModel``.
    """
    wrapped = spk.AtomisticModel(schnet, output_modules)
    return wrapped
Esempio n. 14
0
    cutoff=SchNet_cutoff,
    cutoff_network=spk.nn.cutoff.CosineCutoff,
    trainable_gaussians=trainable_gaussians,
)

# %%

# Atom-wise energy head. derivative='forces' with negative_dr=True presumably
# makes the model also output 'forces' as the negative gradient of the energy
# with respect to positions — confirm against the schnetpack Atomwise docs.
energy_model = spk.atomistic.Atomwise(n_in=n_features,
                                      property='energy',
                                      mean=means['energy'],
                                      stddev=stddevs['energy'],
                                      derivative='forces',
                                      negative_dr=True)

# %%
model = spk.AtomisticModel(representation=schnet, output_modules=energy_model)

# %% Multi-GPUs
import torch
# Wrap the model in DataParallel only when more than one GPU is visible.
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # Batches are split along dim 0, e.g. [30, ...] -> three [10, ...] chunks on 3 GPUs
    model = torch.nn.DataParallel(model)

# %%
import torch


# loss function the same as https://aip.scitation.org/doi/pdf/10.1063/1.5019779
def loss(batch, result):
    # compute the mean squared error on the energies