Example #1
def train_model(args, model, train_loader, val_loader):
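    """Set up a SchNetPack Trainer: clear old logs/checkpoints, then build the optimizer, metrics, and hooks."""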

    # before setting up the trainer, remove previous training checkpoints and logs
    if os.path.exists(os.path.join(args.model_path, 'checkpoints')):
        shutil.rmtree(os.path.join(args.model_path, 'checkpoints'))

    if os.path.exists(os.path.join(args.model_path, 'log.csv')):
        os.remove(os.path.join(args.model_path, 'log.csv'))

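    # optimize only the parameters that still require gradients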
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = Adam(trainable_params, lr=args.lr, weight_decay=4e-7)
    metrics = [
        spk.train.metrics.MeanAbsoluteError(args.property, args.property),
        spk.train.metrics.RootMeanSquaredError(args.property, args.property),
    ]
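    # log metrics to CSV and reduce the learning rate when the validation loss plateaus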
    hooks = [
        trn.CSVHook(log_path=args.model_path, metrics=metrics),
        trn.ReduceLROnPlateauHook(optimizer,
                                  patience=25,
                                  factor=0.8,
                                  min_lr=1e-6,
                                  window_length=1,
                                  stop_after_min=True)
    ]

    loss = simple_loss_fn(args)
    trainer = trn.Trainer(
        model_path=args.model_path,
        model=model,
        hooks=hooks,
        loss_fn=loss,
        optimizer=optimizer,
        train_loader=train_loader,
        validation_loader=val_loader,
    )

    return trainer
Example #2
                     message_steps=args.num_messages,
                     output_layers=args.output_layers,
                     reduce_fn=args.readout_fn,
                     atomwise=args.atomwise,
                     mean=mean['ip'],
                     std=std['ip'])

    # Train the model
    #  Following:
    init_learn_rate = 1e-4
    opt = optim.Adam(model.parameters(), lr=init_learn_rate)

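    # MSE loss and metric on the 'ip' target property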
    loss = trn.build_mse_loss(['ip'])
    metrics = [spk.metrics.MeanSquaredError('ip')]
    hooks = [
        trn.CSVHook(log_path=test_dir, metrics=metrics),
        trn.ReduceLROnPlateauHook(opt,
                                  patience=args.num_epochs // 8,
                                  factor=0.8,
                                  min_lr=1e-6,
                                  stop_after_min=True)
    ]

    trainer = trn.Trainer(
        model_path=test_dir,
        model=model,
        hooks=hooks,
        loss_fn=loss,
        optimizer=opt,
        train_loader=train_loader,
        validation_loader=valid_loader,
    )
Example #3
def run(split_path,dataset_path,n_train=None,n_val=None,n_epochs=1000):

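	# keep logs/checkpoints in a fresh "Info" directory, clearing leftovers from previous runs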
	storage_dir="Info"
	if not os.path.exists(storage_dir):
		os.makedirs(storage_dir)

	if os.path.exists(os.path.join(storage_dir,"checkpoints")):
		shutil.rmtree(os.path.join(storage_dir,"checkpoints"))

	if os.path.exists(os.path.join(storage_dir,"log.csv")):
		os.remove(os.path.join(storage_dir,"log.csv"))

	if os.path.exists(os.path.join(storage_dir,"best_model")):
		os.remove(os.path.join(storage_dir,"best_model"))

	data=MD17(dataset_path)

	atoms,properties=data.get_properties(0)

	# pass num_train/num_val so a new split can be created when split_file does not exist yet
	train,val,test=spk.train_test_split(
		data=data,
		num_train=n_train,
		num_val=n_val,
		split_file=split_path,
		)
	
	train_loader = spk.AtomsLoader(train, batch_size=100, shuffle=True)
	val_loader = spk.AtomsLoader(val, batch_size=100)

	means, stddevs = train_loader.get_statistics(
		spk.datasets.MD17.energy, divide_by_atoms=True
	)

	with open("out.txt","w+") as file:
		file.write("IN MD17_train")

	print('Mean atomization energy / atom:      {:12.4f} [kcal/mol]'.format(means[MD17.energy][0]))
	print('Std. dev. atomization energy / atom: {:12.4f} [kcal/mol]'.format(stddevs[MD17.energy][0]))

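	# SchNet representation: 25-Gaussian distance expansion, 6 interaction blocks, 5 Å cosine cutoff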
	n_features=64
	schnet = spk.representation.SchNet(
		n_atom_basis=n_features,
		n_filters=n_features,
		n_gaussians=25,
		n_interactions=6,
		cutoff=5.,
		cutoff_network=spk.nn.cutoff.CosineCutoff
	)


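	# Atomwise output head: predicts the energy and derives forces as the negative gradient (negative_dr=True)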
	energy_model = spk.atomistic.Atomwise(
		n_in=n_features,
		property=MD17.energy,
		mean=means[MD17.energy],
		stddev=stddevs[MD17.energy],
		derivative=MD17.forces,
		negative_dr=True
	)

	model = spk.AtomisticModel(representation=schnet, output_modules=energy_model)

	# tradeoff
	rho_tradeoff = 0.1
	optimizer=Adam(model.parameters(),lr=1e-3)

	# loss function
	def loss(batch, result):
		# compute the mean squared error on the energies
		diff_energy = batch[MD17.energy]-result[MD17.energy]
		err_sq_energy = torch.mean(diff_energy ** 2)

		# compute the mean squared error on the forces
		diff_forces = batch[MD17.forces]-result[MD17.forces]
		err_sq_forces = torch.mean(diff_forces ** 2)

		# build the combined loss function
		err_sq = rho_tradeoff*err_sq_energy + (1-rho_tradeoff)*err_sq_forces

		return err_sq


	# set up metrics
	metrics = [
		spk.metrics.MeanAbsoluteError(MD17.energy),
		spk.metrics.MeanAbsoluteError(MD17.forces)
	]

	# construct hooks
	hooks = [
		trn.CSVHook(log_path=storage_dir, metrics=metrics),
		trn.ReduceLROnPlateauHook(
			optimizer,
			patience=150, factor=0.8, min_lr=1e-6,
			stop_after_min=True
		)
	]

	trainer = trn.Trainer(
		model_path=storage_dir,
		model=model,
		hooks=hooks,
		loss_fn=loss,
		optimizer=optimizer,
		train_loader=train_loader,
		validation_loader=val_loader,
	)

	# check if a GPU is available and use a CPU otherwise
	if torch.cuda.is_available():
		device = "cuda"
	else:
		device = "cpu"

	# determine number of epochs and train
	trainer.train(
		device=device,
		n_epochs=n_epochs 
		)

	os.rename(os.path.join(storage_dir,"best_model"),os.path.join(storage_dir,"model_new"))
Example #4
File: models.py Project: fonsecag/MLFF
def schnet_train_default(self, train_indices, model_path, old_model_path,
                         schnet_args):

    import schnetpack as spk
    import schnetpack.train as trn
    import torch

    n_val = schnet_args.get("n_val", 100)

    #  LOADING train, val, test
    if isinstance(train_indices, int):
        n_train = train_indices

        # Preparing storage
        storage = os.path.join(self.temp_dir, f"schnet_{n_train}")
        if not os.path.exists(storage):
            os.mkdir(storage)
        split_path = os.path.join(storage, "split.npz")

        train, val, test = spk.train_test_split(data=self.dataset,
                                                num_train=n_train,
                                                num_val=n_val,
                                                split_file=split_path)

    else:
        n_train = len(train_indices)

        # Preparing storage
        storage = os.path.join(self.temp_dir, f"schnet_{n_train}")
        if not os.path.exists(storage):
            os.mkdir(storage)
        split_path = os.path.join(storage, "split.npz")

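        # build the split manually: remove the given train indices, draw a random validation set, keep the rest as test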
        all_ind = np.arange(len(self.dataset))

        #  train
        train_ind = train_indices
        all_ind = np.delete(all_ind, train_ind)

        # val
        val_ind_ind = np.random.choice(np.arange(len(all_ind)),
                                       n_val,
                                       replace=False)
        val_ind = all_ind[val_ind_ind]
        all_ind = np.delete(all_ind, val_ind_ind)

        split_dict = {
            "train_idx": train_ind,
            "val_idx": val_ind,
            "test_idx": all_ind,
        }
        np.savez_compressed(split_path, **split_dict)

        train, val, test = spk.train_test_split(data=self.dataset,
                                                split_file=split_path)

    print_ongoing_process(f"Preparing SchNet training, {len(train)} points",
                          True)

    data = self.dataset

    batch_size = schnet_args.get("batch_size", 10)
    n_features = schnet_args.get("n_features", 64)
    n_gaussians = schnet_args.get("n_gaussians", 25)
    n_interactions = schnet_args.get("n_interactions", 6)
    cutoff = schnet_args.get("cutoff", 5.0)
    learning_rate = schnet_args.get("learning_rate", 1e-3)
    rho_tradeoff = schnet_args.get("rho_tradeoff", 0.1)
    patience = schnet_args.get("patience", 5)
    n_epochs = schnet_args.get("n_epochs", 100)

    #  PRINTING INFO
    i = {}
    i["batch_size"], i["n_features"] = batch_size, n_features
    i["n_gaussians"], i["n_interactions"] = n_gaussians, n_interactions
    i["cutoff"], i["learning_rate"] = cutoff, learning_rate
    i["rho_tradeoff"], i["patience"] = rho_tradeoff, patience
    i["n_epochs"], i["n_val"] = n_epochs, n_val
    print_table("Parameters", None, None, i, width=20)
    print()

    train_loader = spk.AtomsLoader(train, shuffle=True, batch_size=batch_size)
    val_loader = spk.AtomsLoader(val, batch_size=batch_size)

    #  STATISTICS + PRINTS
    means, stddevs = train_loader.get_statistics("energy",
                                                 divide_by_atoms=True)
    print_info(
        "Mean atomization energy / atom:      {:12.4f} [kcal/mol]".format(
            means["energy"][0]))
    print_info(
        "Std. dev. atomization energy / atom: {:12.4f} [kcal/mol]".format(
            stddevs["energy"][0]))

    #  LOADING MODEL
    print_ongoing_process("Loading representation and model")
    schnet = spk.representation.SchNet(
        n_atom_basis=n_features,
        n_filters=n_features,
        n_gaussians=n_gaussians,
        n_interactions=n_interactions,
        cutoff=cutoff,
        cutoff_network=spk.nn.cutoff.CosineCutoff,
    )

    energy_model = spk.atomistic.Atomwise(
        n_in=n_features,
        property="energy",
        mean=means["energy"],
        stddev=stddevs["energy"],
        derivative="forces",
        negative_dr=True,
    )

    model = spk.AtomisticModel(representation=schnet,
                               output_modules=energy_model)
    print_ongoing_process("Loading representation and model", True)

    #  OPTIMIZER AND LOSS
    print_ongoing_process("Defining loss function and optimizer")
    from torch.optim import Adam

    optimizer = Adam(model.parameters(), lr=learning_rate)

    def loss(batch, result):

        # compute the mean squared error on the energies
        diff_energy = batch["energy"] - result["energy"]
        err_sq_energy = torch.mean(diff_energy**2)

        # compute the mean squared error on the forces
        diff_forces = batch["forces"] - result["forces"]
        err_sq_forces = torch.mean(diff_forces**2)

        # build the combined loss function
        err_sq = rho_tradeoff * err_sq_energy + (1 - rho_tradeoff) * err_sq_forces

        return err_sq

    print_ongoing_process("Defining loss function and optimizer", True)

    # METRICS AND HOOKS
    print_ongoing_process("Setting up metrics and hooks")
    metrics = [
        spk.metrics.MeanAbsoluteError("energy"),
        spk.metrics.MeanAbsoluteError("forces"),
    ]

    hooks = [
        trn.CSVHook(log_path=storage, metrics=metrics),
        trn.ReduceLROnPlateauHook(optimizer,
                                  patience=patience,
                                  factor=0.8,
                                  min_lr=1e-6,
                                  stop_after_min=True),
    ]
    print_ongoing_process("Setting up metrics and hooks", True)

    print_ongoing_process("Setting up trainer")

    trainer = trn.Trainer(
        model_path=storage,
        model=model,
        hooks=hooks,
        loss_fn=loss,
        optimizer=optimizer,
        train_loader=train_loader,
        validation_loader=val_loader,
    )

    print_ongoing_process("Setting up trainer", True)

    if torch.cuda.is_available():
        device = "cuda"
        print_info("CUDA device found, training on GPU")

    else:
        device = "cpu"
        print_info("No CUDA device found, training on CPU")

    print_ongoing_process(f"Training {n_epochs} ecpochs, out in {storage}")
    trainer.train(device=device, n_epochs=n_epochs)
    print_ongoing_process(f"Training {n_epochs} epochs, out in {storage}",
                          True)

    os.mkdir(model_path)

    os.rename(os.path.join(storage, "best_model"),
              os.path.join(model_path, "model"))
    shutil.copy(split_path, os.path.join(model_path, "split.npz"))
Example #5
def train_schnet(
    model: Union[TorchMessage, torch.nn.Module, Path],
    database: Dict[str, float],
    num_epochs: int,
    reset_weights: bool = True,
    property_name: str = 'output',
    test_set: Optional[List[str]] = None,
    device: str = 'cpu',
    batch_size: int = 32,
    validation_split: float = 0.1,
    bootstrap: bool = False,
    random_state: int = 1,
    learning_rate: float = 1e-3,
    patience: Optional[int] = None,
    timeout: Optional[float] = None
) -> Union[Tuple[TorchMessage, pd.DataFrame], Tuple[TorchMessage, pd.DataFrame,
                                                    List[float]]]:
    """Train a SchNet model

    Args:
        model: Model to be retrained
        database: Mapping of XYZ format structure to property
        num_epochs: Number of training epochs
        property_name: Name of the property being predicted
        reset_weights: Whether to re-initialize weights before training, or continue from the previous weights
        test_set: Hold-out set. If provided, the function will also return its predictions on this set
        device: Device (e.g., 'cuda', 'cpu') used for training
        batch_size: Batch size during training
        validation_split: Fraction of the training set to use for the validation loss
        bootstrap: Whether to take a bootstrap sample of the training set before training
        random_state: Random seed used for generating validation set and bootstrap sampling
        learning_rate: Initial learning rate for optimizer
        patience: Patience until learning rate is lowered. Default: epochs / 8
        timeout: Maximum training time in seconds
    Returns:
        - model: Retrained model
        - history: Training history
        - test_pred: Predictions on ``test_set``, if provided
    """

    # Make sure the models are converted to Torch models
    if isinstance(model, TorchMessage):
        model = model.get_model(device)
    elif isinstance(model, (Path, str)):
        model = torch.load(model,
                           map_location='cpu')  # Load to main memory first

    # If desired, re-initialize weights
    if reset_weights:
        for module in model.modules():
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()

    # Separate the database into molecules and properties
    xyz, y = zip(*database.items())
    xyz = np.array(xyz)
    y = np.array(y)

    # Convert the xyz files to ase Atoms
    atoms = np.array([next(read_xyz(StringIO(x), slice(None))) for x in xyz])

    # Make the training and validation splits
    rng = np.random.RandomState(random_state)
    train_split = rng.rand(len(xyz)) > validation_split
    train_X = atoms[train_split]
    train_y = y[train_split]
    valid_X = atoms[~train_split]
    valid_y = y[~train_split]

    # Perform a bootstrap sample of the training data
    if bootstrap:
        sample = rng.choice(len(train_X), size=(len(train_X), ), replace=True)
        train_X = train_X[sample]
        train_y = train_y[sample]

    # Start the training process
    with TemporaryDirectory() as td:
        # Save the data to an ASE Atoms database
        train_file = os.path.join(td, 'train_data.db')
        db = AtomsData(train_file, available_properties=[property_name])
        db.add_systems(train_X, [{property_name: i} for i in train_y])
        train_loader = AtomsLoader(db, batch_size=batch_size, shuffle=True)

        valid_file = os.path.join(td, 'valid_data.db')
        db = AtomsData(valid_file, available_properties=[property_name])
        db.add_systems(valid_X, [{property_name: i} for i in valid_y])
        valid_loader = AtomsLoader(db, batch_size=batch_size)

        # Make the trainer
        opt = optim.Adam(model.parameters(), lr=learning_rate)

        # loss and metric must target the same property name stored in the training database
        loss = trn.build_mse_loss([property_name])
        metrics = [spk.metrics.MeanSquaredError(property_name)]
        if patience is None:
            patience = num_epochs // 8
        hooks = [
            trn.CSVHook(log_path=td, metrics=metrics),
            trn.ReduceLROnPlateauHook(opt,
                                      patience=patience,
                                      factor=0.8,
                                      min_lr=1e-6,
                                      stop_after_min=True)
        ]

        if timeout is not None:
            hooks.append(TimeoutHook(timeout))

        trainer = trn.Trainer(
            model_path=td,
            model=model,
            hooks=hooks,
            loss_fn=loss,
            optimizer=opt,
            train_loader=train_loader,
            validation_loader=valid_loader,
            checkpoint_interval=num_epochs + 1  # Turns off checkpointing
        )

        trainer.train(device, n_epochs=num_epochs)

        # Load in the best model
        model = torch.load(os.path.join(td, 'best_model'))

        # If desired, report the performance on a test set
        test_pred = None
        if test_set is not None:
            test_pred = evaluate_schnet([model],
                                        test_set,
                                        property_name=property_name,
                                        batch_size=batch_size,
                                        device=device)

        # Move the model off of the GPU to save memory
        if 'cuda' in device:
            model.to('cpu')

        # Load in the training results
        train_results = pd.read_csv(os.path.join(td, 'log.csv'))

        # Return the results
        if test_pred is None:
            return TorchMessage(model), train_results
        else:
            return TorchMessage(model), train_results, test_pred[:, 0].tolist()
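
# A minimal usage sketch for train_schnet (the file name and data below are hypothetical;
# all other arguments fall back to the defaults documented above):
#
#   message, history = train_schnet(
#       model=Path('starting_model.pt'),    # a torch.save()'d SchNet model on disk
#       database={xyz_string: -1.23},       # XYZ-format text mapped to the target value
#       num_epochs=8,
#       device='cpu',
#   )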
Example #6
# %%
import schnetpack.train as trn

# set up metrics
metrics = [
    spk.metrics.MeanAbsoluteError('energy'),
    spk.metrics.MeanAbsoluteError('forces'),
    spk.metrics.RootMeanSquaredError('energy'),
    spk.metrics.RootMeanSquaredError('forces'),
]

# construct hooks
hooks = [
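    # WandBHook is assumed to be a user-defined hook (not part of schnetpack.train) that mirrors the metrics to Weights & Biases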
    WandBHook(log_path=forcetut, metrics=metrics),
    trn.CSVHook(log_path=forcetut, metrics=metrics),
    trn.ReduceLROnPlateauHook(optimizer,
                              patience=patience,
                              factor=factor,
                              min_lr=min_lr,
                              stop_after_min=True)
]

# %%
trainer = trn.Trainer(
    model_path=forcetut,
    model=model,
    hooks=hooks,
    loss_fn=loss,
    optimizer=optimizer,
    train_loader=train_loader,