import os
import shutil

import schnetpack as spk
import schnetpack.train as trn
from torch.optim import Adam


def train_model(args, model, train_loader, val_loader):
    # Before setting up the trainer, remove previous training checkpoints and logs
    if os.path.exists(os.path.join(args.model_path, 'checkpoints')):
        shutil.rmtree(os.path.join(args.model_path, 'checkpoints'))
    if os.path.exists(os.path.join(args.model_path, 'log.csv')):
        os.remove(os.path.join(args.model_path, 'log.csv'))

    # Only optimize parameters that require gradients
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = Adam(trainable_params, lr=args.lr, weight_decay=4e-7)

    metrics = [
        spk.train.metrics.MeanAbsoluteError(args.property, args.property),
        spk.train.metrics.RootMeanSquaredError(args.property, args.property),
    ]
    hooks = [
        trn.CSVHook(log_path=args.model_path, metrics=metrics),
        trn.ReduceLROnPlateauHook(optimizer, patience=25, factor=0.8, min_lr=1e-6,
                                  window_length=1, stop_after_min=True)
    ]

    loss = simple_loss_fn(args)
    trainer = trn.Trainer(
        model_path=args.model_path,
        model=model,
        hooks=hooks,
        loss_fn=loss,
        optimizer=optimizer,
        train_loader=train_loader,
        validation_loader=val_loader,
    )
    return trainer
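# A minimal usage sketch for train_model above (hypothetical argument values;
# `model`, `train_loader`, and `val_loader` are assumed to be built elsewhere,
# and `simple_loss_fn` must be defined in the surrounding module).
import argparse
import torch

args = argparse.Namespace(model_path='./model', lr=1e-4, property='energy')
trainer = train_model(args, model, train_loader, val_loader)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
trainer.train(device=device, n_epochs=500)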
    message_steps=args.num_messages,
    output_layers=args.output_layers,
    reduce_fn=args.readout_fn,
    atomwise=args.atomwise,
    mean=mean['ip'],
    std=std['ip'])

# Train the model
# Following:
init_learn_rate = 1e-4
opt = optim.Adam(model.parameters(), lr=init_learn_rate)
loss = trn.build_mse_loss(['ip'])
metrics = [spk.metrics.MeanSquaredError('ip')]
hooks = [
    trn.CSVHook(log_path=test_dir, metrics=metrics),
    trn.ReduceLROnPlateauHook(opt, patience=args.num_epochs // 8, factor=0.8,
                              min_lr=1e-6, stop_after_min=True)
]
trainer = trn.Trainer(
    model_path=test_dir,
    model=model,
    hooks=hooks,
    loss_fn=loss,
    optimizer=opt,
    train_loader=train_loader,
    validation_loader=valid_loader,
)
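# Presumably followed by the training call itself; the device selection and
# the args.num_epochs name are assumptions based on the hook setup above.
trainer.train(device='cuda' if torch.cuda.is_available() else 'cpu',
              n_epochs=args.num_epochs)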
import os
import shutil

import torch
import schnetpack as spk
import schnetpack.train as trn
from schnetpack.datasets import MD17
from torch.optim import Adam


def run(split_path, dataset_path, n_train=None, n_val=None, n_epochs=1000):
    storage_dir = "Info"
    if not os.path.exists(storage_dir):
        os.makedirs(storage_dir)

    # Remove artifacts from any previous run
    if os.path.exists(os.path.join(storage_dir, "checkpoints")):
        shutil.rmtree(os.path.join(storage_dir, "checkpoints"))
    if os.path.exists(os.path.join(storage_dir, "log.csv")):
        os.remove(os.path.join(storage_dir, "log.csv"))
    if os.path.exists(os.path.join(storage_dir, "best_model")):
        os.remove(os.path.join(storage_dir, "best_model"))

    data = MD17(dataset_path)
    atoms, properties = data.get_properties(0)

    # Pass the requested split sizes; they are used when split_file does not
    # yet exist (the original left n_train and n_val unused)
    train, val, test = spk.train_test_split(
        data=data,
        num_train=n_train,
        num_val=n_val,
        split_file=split_path,
    )
    train_loader = spk.AtomsLoader(train, batch_size=100, shuffle=True)
    val_loader = spk.AtomsLoader(val, batch_size=100)

    means, stddevs = train_loader.get_statistics(
        spk.datasets.MD17.energy, divide_by_atoms=True
    )

    # Debug marker
    with open("out.txt", "w+") as file:
        file.write("IN MD17_train")

    print('Mean atomization energy / atom: {:12.4f} [kcal/mol]'.format(means[MD17.energy][0]))
    print('Std. dev. atomization energy / atom: {:12.4f} [kcal/mol]'.format(stddevs[MD17.energy][0]))

    n_features = 64
    schnet = spk.representation.SchNet(
        n_atom_basis=n_features,
        n_filters=n_features,
        n_gaussians=25,
        n_interactions=6,
        cutoff=5.,
        cutoff_network=spk.nn.cutoff.CosineCutoff
    )
    energy_model = spk.atomistic.Atomwise(
        n_in=n_features,
        property=MD17.energy,
        mean=means[MD17.energy],
        stddev=stddevs[MD17.energy],
        derivative=MD17.forces,
        negative_dr=True
    )
    model = spk.AtomisticModel(representation=schnet, output_modules=energy_model)

    # Trade-off between energy and force contributions to the loss
    rho_tradeoff = 0.1

    optimizer = Adam(model.parameters(), lr=1e-3)

    # loss function
    def loss(batch, result):
        # compute the mean squared error on the energies
        diff_energy = batch[MD17.energy] - result[MD17.energy]
        err_sq_energy = torch.mean(diff_energy ** 2)

        # compute the mean squared error on the forces
        diff_forces = batch[MD17.forces] - result[MD17.forces]
        err_sq_forces = torch.mean(diff_forces ** 2)

        # build the combined loss function
        err_sq = rho_tradeoff * err_sq_energy + (1 - rho_tradeoff) * err_sq_forces
        return err_sq

    # set up metrics
    metrics = [
        spk.metrics.MeanAbsoluteError(MD17.energy),
        spk.metrics.MeanAbsoluteError(MD17.forces)
    ]

    # construct hooks
    hooks = [
        trn.CSVHook(log_path=storage_dir, metrics=metrics),
        trn.ReduceLROnPlateauHook(
            optimizer, patience=150, factor=0.8, min_lr=1e-6,
            stop_after_min=True
        )
    ]

    trainer = trn.Trainer(
        model_path=storage_dir,
        model=model,
        hooks=hooks,
        loss_fn=loss,
        optimizer=optimizer,
        train_loader=train_loader,
        validation_loader=val_loader,
    )

    # check if a GPU is available and use a CPU otherwise
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # determine number of epochs and train
    trainer.train(device=device, n_epochs=n_epochs)

    os.rename(os.path.join(storage_dir, "best_model"),
              os.path.join(storage_dir, "model_new"))
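# An illustrative invocation of run() (paths are assumptions; dataset_path
# should point at an existing MD17 database file, and split_path at where the
# train/validation/test split is stored or should be created).
run(split_path="Info/split.npz", dataset_path="ethanol.db",
    n_train=950, n_val=50, n_epochs=1000)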
def schnet_train_default(self, train_indices, model_path, old_model_path, schnet_args):
    import schnetpack as spk
    import schnetpack.train as trn
    import torch

    n_val = schnet_args.get("n_val", 100)

    # LOADING train, val, test
    if isinstance(train_indices, int):
        n_train = train_indices

        # Preparing storage
        storage = os.path.join(self.temp_dir, f"schnet_{n_train}")
        if not os.path.exists(storage):
            os.mkdir(storage)
        split_path = os.path.join(storage, "split.npz")

        train, val, test = spk.train_test_split(data=self.dataset,
                                                num_train=n_train,
                                                num_val=n_val,
                                                split_file=split_path)
    else:
        n_train = len(train_indices)

        # Preparing storage
        storage = os.path.join(self.temp_dir, f"schnet_{n_train}")
        if not os.path.exists(storage):
            os.mkdir(storage)
        split_path = os.path.join(storage, "split.npz")

        all_ind = np.arange(len(self.dataset))

        # train
        train_ind = train_indices
        all_ind = np.delete(all_ind, train_ind)

        # val
        val_ind_ind = np.random.choice(np.arange(len(all_ind)), n_val, replace=False)
        val_ind = all_ind[val_ind_ind]
        all_ind = np.delete(all_ind, val_ind_ind)

        split_dict = {
            "train_idx": train_ind,
            "val_idx": val_ind,
            "test_idx": all_ind,
        }
        np.savez_compressed(split_path, **split_dict)

        train, val, test = spk.train_test_split(data=self.dataset, split_file=split_path)

    print_ongoing_process(f"Preparing SchNet training, {len(train)} points", True)

    data = self.dataset
    batch_size = schnet_args.get("batch_size", 10)
    n_features = schnet_args.get("n_features", 64)
    n_gaussians = schnet_args.get("n_gaussians", 25)
    n_interactions = schnet_args.get("n_interactions", 6)
    cutoff = schnet_args.get("cutoff", 5.0)
    learning_rate = schnet_args.get("learning_rate", 1e-3)
    rho_tradeoff = schnet_args.get("rho_tradeoff", 0.1)
    patience = schnet_args.get("patience", 5)
    n_epochs = schnet_args.get("n_epochs", 100)

    # PRINTING INFO
    i = {}
    i["batch_size"], i["n_features"] = batch_size, n_features
    i["n_gaussians"], i["n_interactions"] = n_gaussians, n_interactions
    i["cutoff"], i["learning_rate"] = cutoff, learning_rate
    i["rho_tradeoff"], i["patience"] = rho_tradeoff, patience
    i["n_epochs"], i["n_val"] = n_epochs, n_val
    print_table("Parameters", None, None, i, width=20)
    print()

    train_loader = spk.AtomsLoader(train, shuffle=True, batch_size=batch_size)
    val_loader = spk.AtomsLoader(val, batch_size=batch_size)

    # STATISTICS + PRINTS
    means, stddevs = train_loader.get_statistics("energy", divide_by_atoms=True)
    print_info("Mean atomization energy / atom: {:12.4f} [kcal/mol]".format(means["energy"][0]))
    print_info("Std. dev. atomization energy / atom: {:12.4f} [kcal/mol]".format(stddevs["energy"][0]))

    # LOADING MODEL
    print_ongoing_process("Loading representation and model")
    schnet = spk.representation.SchNet(
        n_atom_basis=n_features,
        n_filters=n_features,
        n_gaussians=n_gaussians,
        n_interactions=n_interactions,
        cutoff=cutoff,
        cutoff_network=spk.nn.cutoff.CosineCutoff,
    )
    energy_model = spk.atomistic.Atomwise(
        n_in=n_features,
        property="energy",
        mean=means["energy"],
        stddev=stddevs["energy"],
        derivative="forces",
        negative_dr=True,
    )
    model = spk.AtomisticModel(representation=schnet, output_modules=energy_model)
    print_ongoing_process("Loading representation and model", True)

    # OPTIMIZER AND LOSS
    print_ongoing_process("Defining loss function and optimizer")
    from torch.optim import Adam

    optimizer = Adam(model.parameters(), lr=learning_rate)

    def loss(batch, result):
        # compute the mean squared error on the energies
        diff_energy = batch["energy"] - result["energy"]
        err_sq_energy = torch.mean(diff_energy ** 2)

        # compute the mean squared error on the forces
        diff_forces = batch["forces"] - result["forces"]
        err_sq_forces = torch.mean(diff_forces ** 2)

        # build the combined loss function
        err_sq = rho_tradeoff * err_sq_energy + (1 - rho_tradeoff) * err_sq_forces
        return err_sq

    print_ongoing_process("Defining loss function and optimizer", True)

    # METRICS AND HOOKS
    print_ongoing_process("Setting up metrics and hooks")
    metrics = [
        spk.metrics.MeanAbsoluteError("energy"),
        spk.metrics.MeanAbsoluteError("forces"),
    ]
    hooks = [
        trn.CSVHook(log_path=storage, metrics=metrics),
        # Use the configured patience (the original hardcoded 5 here)
        trn.ReduceLROnPlateauHook(optimizer, patience=patience, factor=0.8,
                                  min_lr=1e-6, stop_after_min=True),
    ]
    print_ongoing_process("Setting up metrics and hooks", True)

    print_ongoing_process("Setting up trainer")
    trainer = trn.Trainer(
        model_path=storage,
        model=model,
        hooks=hooks,
        loss_fn=loss,
        optimizer=optimizer,
        train_loader=train_loader,
        validation_loader=val_loader,
    )
    print_ongoing_process("Setting up trainer", True)

    if torch.cuda.is_available():
        device = "cuda"
        print_info("Cuda cores found, training on GPU")
    else:
        device = "cpu"
        print_info("No cuda cores found, training on CPU")

    print_ongoing_process(f"Training {n_epochs} epochs, out in {storage}")
    trainer.train(device=device, n_epochs=n_epochs)
    print_ongoing_process(f"Training {n_epochs} epochs, out in {storage}", True)

    os.mkdir(model_path)
    os.rename(os.path.join(storage, "best_model"), os.path.join(model_path, "model"))
    shutil.copy(split_path, os.path.join(model_path, "split.npz"))
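# Illustrative call (inside the owning class; the `temp_dir` and `dataset`
# attributes and the print_* helpers are assumed from the method body):
#   self.schnet_train_default(train_indices=500,
#                             model_path="models/schnet_500",
#                             old_model_path=None,
#                             schnet_args={"n_epochs": 200, "batch_size": 10})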
def train_schnet(model: Union[TorchMessage, torch.nn.Module, Path],
                 database: Dict[str, float],
                 num_epochs: int,
                 reset_weights: bool = True,
                 property_name: str = 'output',
                 test_set: Optional[List[str]] = None,
                 device: str = 'cpu',
                 batch_size: int = 32,
                 validation_split: float = 0.1,
                 bootstrap: bool = False,
                 random_state: int = 1,
                 learning_rate: float = 1e-3,
                 patience: int = None,
                 timeout: float = None) \
        -> Union[Tuple[TorchMessage, pd.DataFrame],
                 Tuple[TorchMessage, pd.DataFrame, List[float]]]:
    """Train a SchNet model

    Args:
        model: Model to be retrained
        database: Mapping of XYZ-format structure to property
        num_epochs: Number of training epochs
        reset_weights: Whether to re-initialize weights before training,
            or start training from the previous weights
        property_name: Name of the property being predicted
        test_set: Hold-out set. If provided, the function will return the
            model's predictions on these structures
        device: Device (e.g., 'cuda', 'cpu') used for training
        batch_size: Batch size during training
        validation_split: Fraction of the training set to use for the validation loss
        bootstrap: Whether to take a bootstrap sample of the training set before training
        random_state: Random seed used for generating the validation set and bootstrap sampling
        learning_rate: Initial learning rate for the optimizer
        patience: Patience until the learning rate is lowered. Default: num_epochs // 8
        timeout: Maximum training time in seconds

    Returns:
        - model: Retrained model
        - history: Training history
        - test_pred: Predictions on ``test_set``, if provided
    """

    # Make sure the model is converted to a Torch model
    if isinstance(model, TorchMessage):
        model = model.get_model(device)
    elif isinstance(model, (Path, str)):
        model = torch.load(model, map_location='cpu')  # Load to main memory first

    # If desired, re-initialize weights
    if reset_weights:
        for module in model.modules():
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()

    # Separate the database into molecules and properties
    xyz, y = zip(*database.items())
    xyz = np.array(xyz)
    y = np.array(y)

    # Convert the XYZ strings to ase Atoms
    atoms = np.array([next(read_xyz(StringIO(x), slice(None))) for x in xyz])

    # Make the training and validation splits
    rng = np.random.RandomState(random_state)
    train_split = rng.rand(len(xyz)) > validation_split
    train_X = atoms[train_split]
    train_y = y[train_split]
    valid_X = atoms[~train_split]
    valid_y = y[~train_split]

    # Perform a bootstrap sample of the training data
    if bootstrap:
        sample = rng.choice(len(train_X), size=(len(train_X),), replace=True)
        train_X = train_X[sample]
        train_y = train_y[sample]

    # Start the training process
    with TemporaryDirectory() as td:
        # Save the data to ASE Atoms databases
        train_file = os.path.join(td, 'train_data.db')
        db = AtomsData(train_file, available_properties=[property_name])
        db.add_systems(train_X, [{property_name: i} for i in train_y])
        train_loader = AtomsLoader(db, batch_size=batch_size, shuffle=True)

        valid_file = os.path.join(td, 'valid_data.db')
        db = AtomsData(valid_file, available_properties=[property_name])
        db.add_systems(valid_X, [{property_name: i} for i in valid_y])
        valid_loader = AtomsLoader(db, batch_size=batch_size)

        # Make the trainer, with loss and metrics computed on the property
        # being predicted (the original hardcoded 'delta', which only matches
        # the database when property_name == 'delta')
        opt = optim.Adam(model.parameters(), lr=learning_rate)
        loss = trn.build_mse_loss([property_name])
        metrics = [spk.metrics.MeanSquaredError(property_name)]

        if patience is None:
            patience = num_epochs // 8
        hooks = [
            trn.CSVHook(log_path=td, metrics=metrics),
            trn.ReduceLROnPlateauHook(opt, patience=patience, factor=0.8,
                                      min_lr=1e-6, stop_after_min=True)
        ]
        if timeout is not None:
            hooks.append(TimeoutHook(timeout))

        trainer = trn.Trainer(
            model_path=td,
            model=model,
            hooks=hooks,
            loss_fn=loss,
            optimizer=opt,
            train_loader=train_loader,
            validation_loader=valid_loader,
            checkpoint_interval=num_epochs + 1  # Turns off checkpointing
        )

        trainer.train(device, n_epochs=num_epochs)

        # Load in the best model
        model = torch.load(os.path.join(td, 'best_model'))

        # If desired, report the performance on a test set
        test_pred = None
        if test_set is not None:
            test_pred = evaluate_schnet([model], test_set,
                                        property_name=property_name,
                                        batch_size=batch_size, device=device)

        # Move the model off of the GPU to save memory
        if 'cuda' in device:
            model.to('cpu')

        # Load in the training results
        train_results = pd.read_csv(os.path.join(td, 'log.csv'))

        # Return the results
        if test_pred is None:
            return TorchMessage(model), train_results
        else:
            return TorchMessage(model), train_results, test_pred[:, 0].tolist()
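# A minimal usage sketch for train_schnet (the file name, XYZ string, and
# target value are illustrative assumptions; in practice `database` holds
# many structures so the random validation split is non-empty).
water_xyz = "3\n\nO 0.000 0.000 0.119\nH 0.000 0.763 -0.477\nH 0.000 -0.763 -0.477"
database = {water_xyz: -76.4}
message, history = train_schnet('starting_model.pt', database,
                                num_epochs=8, property_name='output')
print(history.tail(1))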
# %%
import schnetpack.train as trn

# set up metrics
metrics = [
    spk.metrics.MeanAbsoluteError('energy'),
    spk.metrics.MeanAbsoluteError('forces'),
    spk.metrics.RootMeanSquaredError('energy'),
    spk.metrics.RootMeanSquaredError('forces'),
]

# construct hooks
hooks = [
    WandBHook(log_path=forcetut, metrics=metrics),
    trn.CSVHook(log_path=forcetut, metrics=metrics),
    trn.ReduceLROnPlateauHook(optimizer, patience=patience, factor=factor,
                              min_lr=min_lr, stop_after_min=True)
]

# %%
trainer = trn.Trainer(
    model_path=forcetut,
    model=model,
    hooks=hooks,
    loss_fn=loss,
    optimizer=optimizer,
    train_loader=train_loader,
    validation_loader=val_loader,  # assumed name; the snippet was cut off here
)
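# %%
# Likely next cell: run the training loop (the device check and epoch count
# are illustrative, following the pattern of the other snippets above).
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
trainer.train(device=device, n_epochs=200)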