def get_data_from_atoms(self, dataset):
    """
    Build a train loader from a list of ASE atoms objects, to be used in
    place of the OCP trainer's default train_loader.
    """
    a2g = AtomsToGraphs(
        max_neigh=self.max_neighbors,
        radius=self.cutoff,
        r_energy=True,
        r_forces=True,
        r_distances=True,
        r_edges=False,
    )
    graphs_list = [a2g.convert(atoms) for atoms in dataset]

    # Tag each graph with dummy frame/system ids expected by the trainer.
    for graph in graphs_list:
        graph.fid = 0
        graph.sid = 0

    graphs_list_dataset = GraphsListDataset(graphs_list)

    train_sampler = self.trainer.get_sampler(
        graphs_list_dataset,
        self.mlp_params.get("optim", {}).get("batch_size", 1),
        shuffle=False,
    )
    self.trainer.train_sampler = train_sampler

    data_loader = self.trainer.get_dataloader(
        graphs_list_dataset,
        train_sampler,
    )
    return data_loader
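A brief usage sketch, under the assumption that the surrounding object (called `learner` here, a hypothetical name) exposes the `trainer`, `max_neighbors`, `cutoff`, and `mlp_params` attributes referenced above; the copper slab and its single-point labels are placeholders, not real training data.

from ase.build import fcc111
from ase.calculators.singlepoint import SinglePointCalculator

slab = fcc111("Cu", size=(2, 2, 3), vacuum=8.0)
# Reference labels are required because r_energy/r_forces are True above;
# the values here are placeholders standing in for DFT results.
slab.calc = SinglePointCalculator(
    slab, energy=-1.0, forces=[[0.0, 0.0, 0.0]] * len(slab))

train_loader = learner.get_data_from_atoms([slab])
learner.trainer.train_loader = train_loader  # assumption: the trainer reads this attribute
learner.trainer.train()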
def process(self):
    print("### Preprocessing atoms objects from: {}".format(
        self.raw_file_names[0]))
    traj = Trajectory(self.raw_file_names[0])
    a2g = AtomsToGraphs(
        max_neigh=self.config.get("max_neigh", 12),
        radius=self.config.get("radius", 6),
        dummy_distance=self.config.get("radius", 6) + 1,
        dummy_index=-1,
        r_energy=True,
        r_forces=True,
        r_distances=False,
    )

    data_list = []
    for atoms in tqdm(
            traj,
            desc="preprocessing atomic features",
            total=len(traj),
            unit="structure",
    ):
        data_list.append(a2g.convert(atoms))

    self.data, self.slices = collate(data_list)
    torch.save((self.data, self.slices), self.processed_file_names[0])
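A small sketch of what process() leaves on disk, assuming it has already run once; the file path is illustrative, and the attribute names follow the usual AtomsToGraphs / torch_geometric collate layout rather than anything specific to this code.

import torch

data, slices = torch.load("processed/relaxation_graphs.pt")  # hypothetical path
print(data.y.shape)       # stacked per-frame energies (r_energy=True above)
print(data.force.shape)   # stacked per-atom forces (r_forces=True above)
print(list(slices)[:3])   # per-attribute slice boundaries used to index frames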
class OCPCalculator(Calculator):
    implemented_properties = ["energy", "forces"]

    def __init__(self, trainer, pbc_graph=False):
        """
        OCP-ASE Calculator

        Args:
            trainer: Object
                ML trainer for energy and force predictions.
        """
        Calculator.__init__(self)
        self.trainer = trainer
        self.pbc_graph = pbc_graph
        self.a2g = AtomsToGraphs(
            max_neigh=50,
            radius=6,
            r_energy=False,
            r_forces=False,
            r_distances=False,
        )

    def train(self):
        self.trainer.train()

    def load_pretrained(self, checkpoint_path):
        """
        Load existing trained model

        Args:
            checkpoint_path: string
                Path to trained model
        """
        try:
            self.trainer.load_pretrained(checkpoint_path)
        except NotImplementedError:
            print("Unable to load checkpoint!")

    def calculate(self, atoms, properties, system_changes):
        Calculator.calculate(self, atoms, properties, system_changes)
        data_object = self.a2g.convert(atoms)
        batch = data_list_collater([data_object])

        if self.pbc_graph:
            edge_index, cell_offsets, neighbors = radius_graph_pbc(
                batch, 6, 50, batch.pos.device)
            batch.edge_index = edge_index
            batch.cell_offsets = cell_offsets
            batch.neighbors = neighbors

        predictions = self.trainer.predict(batch, per_image=False)

        if self.trainer.name == "s2ef":
            self.results["energy"] = predictions["energy"].item()
            self.results["forces"] = predictions["forces"].cpu().numpy()
        elif self.trainer.name == "is2re":
            self.results["energy"] = predictions["energy"].item()
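A short usage sketch through the standard ASE calculator interface, assuming a trainer object has already been constructed elsewhere; the copper slab is only an example structure.

from ase.build import fcc111

calc = OCPCalculator(trainer, pbc_graph=True)  # `trainer` built beforehand
slab = fcc111("Cu", size=(2, 2, 3), vacuum=8.0)
slab.calc = calc
print(slab.get_potential_energy())   # routed through calculate() above
print(slab.get_forces().shape)       # (n_atoms, 3) for an s2ef trainer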
class OCPCalculator(Calculator):
    implemented_properties = ["energy", "forces"]

    def __init__(self, trainer):
        """
        OCP-ASE Calculator

        Args:
            trainer: Object
                ML trainer for energy and force predictions.
        """
        Calculator.__init__(self)
        self.trainer = trainer
        self.a2g = AtomsToGraphs(
            max_neigh=12,
            radius=6,
            dummy_distance=7,
            dummy_index=-1,
            r_energy=False,
            r_forces=False,
            r_distances=False,
        )

    def train(self):
        self.trainer.train()

    def load_pretrained(self, checkpoint_path):
        """
        Load existing trained model

        Args:
            checkpoint_path: string
                Path to trained model
        """
        try:
            self.trainer.load_pretrained(checkpoint_path)
        except NotImplementedError:
            print("Unable to load checkpoint!")

    def calculate(self, atoms, properties, system_changes):
        Calculator.calculate(self, atoms, properties, system_changes)
        data_object = self.a2g.convert(atoms)
        batch = Batch.from_data_list([data_object])
        predictions = self.trainer.predict(batch)
        self.results["energy"] = predictions["energy"][0]
        self.results["forces"] = predictions["forces"][0]
class OCPCalculator(Calculator):
    implemented_properties = ["energy", "forces"]

    def __init__(self, config_yml, checkpoint=None, cutoff=6,
                 max_neighbors=50):
        """
        OCP-ASE Calculator

        Args:
            config_yml (str): Path to yaml config.
            checkpoint (str): Path to trained checkpoint.
            cutoff (int): Cutoff radius to be used for data preprocessing.
            max_neighbors (int): Maximum number of neighbors to store for a
                given atom.
        """
        setup_imports()
        setup_logging()
        Calculator.__init__(self)

        config = yaml.safe_load(open(config_yml, "r"))
        if "includes" in config:
            for include in config["includes"]:
                include_config = yaml.safe_load(open(include, "r"))
                config.update(include_config)

        # Save config so obj can be transported over network (pkl)
        self.config = copy.deepcopy(config)
        self.config["checkpoint"] = checkpoint

        self.trainer = registry.get_trainer_class(
            config.get("trainer", "simple"))(
                task=config["task"],
                model=config["model"],
                dataset=config["dataset"],
                optimizer=config["optim"],
                identifier="",
                slurm=config.get("slurm", {}),
                local_rank=config.get("local_rank", 0),
                is_debug=config.get("is_debug", True),
                cpu=True,
            )

        if checkpoint is not None:
            self.load_checkpoint(checkpoint)

        self.a2g = AtomsToGraphs(
            max_neigh=max_neighbors,
            radius=cutoff,
            r_energy=False,
            r_forces=False,
            r_distances=False,
        )

    def train(self):
        self.trainer.train()

    def load_checkpoint(self, checkpoint_path):
        """
        Load existing trained model

        Args:
            checkpoint_path: string
                Path to trained model
        """
        try:
            self.trainer.load_checkpoint(checkpoint_path)
        except NotImplementedError:
            logging.warning("Unable to load checkpoint!")

    def calculate(self, atoms, properties, system_changes):
        Calculator.calculate(self, atoms, properties, system_changes)
        data_object = self.a2g.convert(atoms)
        batch = data_list_collater([data_object])

        predictions = self.trainer.predict(batch,
                                           per_image=False,
                                           disable_tqdm=True)

        if self.trainer.name == "s2ef":
            self.results["energy"] = predictions["energy"].item()
            self.results["forces"] = predictions["forces"].cpu().numpy()
        elif self.trainer.name == "is2re":
            self.results["energy"] = predictions["energy"].item()
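A usage sketch for the config-driven variant above; the yaml and checkpoint paths below are illustrative placeholders, not files shipped with the code.

from ase.build import fcc111

calc = OCPCalculator(
    config_yml="configs/s2ef/example_model.yml",      # hypothetical path
    checkpoint="checkpoints/example_checkpoint.pt",   # hypothetical path
    cutoff=6,
    max_neighbors=50,
)
slab = fcc111("Cu", size=(2, 2, 3), vacuum=8.0)
slab.calc = calc
energy = slab.get_potential_energy()
forces = slab.get_forces()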
class Trainer(ForcesTrainer):
    def __init__(self, config_yml=None, checkpoint=None, cutoff=6,
                 max_neighbors=50):
        setup_imports()
        setup_logging()

        # Either the config path or the checkpoint path needs to be provided.
        assert config_yml or checkpoint is not None

        if config_yml is not None:
            if isinstance(config_yml, str):
                config = yaml.safe_load(open(config_yml, "r"))
                if "includes" in config:
                    for include in config["includes"]:
                        # Resolve the include relative to config_yml.
                        path = os.path.join(
                            config_yml.split("configs")[0], include)
                        include_config = yaml.safe_load(open(path, "r"))
                        config.update(include_config)
            else:
                config = config_yml
            # Only keep the train dataset entry, which may carry normalizer values.
            config["dataset"] = config["dataset"][0]
        else:
            # Load the config stored inside the checkpoint.
            config = torch.load(
                checkpoint, map_location=torch.device("cpu"))["config"]

            # Select the trainer based on the dataset used.
            if config["task"]["dataset"] == "trajectory_lmdb":
                config["trainer"] = "forces"
            else:
                config["trainer"] = "energy"

            config["model_attributes"]["name"] = config.pop("model")
            config["model"] = config["model_attributes"]

        # Calculate the edge indices on the fly.
        config["model"]["otf_graph"] = True

        # Save config so obj can be transported over network (pkl).
        self.config = copy.deepcopy(config)
        self.config["checkpoint"] = checkpoint

        if "normalizer" not in config:
            del config["dataset"]["src"]
            config["normalizer"] = config["dataset"]

        super().__init__(
            task=config["task"],
            model=config["model"],
            dataset=None,
            optimizer=config["optim"],
            identifier="",
            normalizer=config["normalizer"],
            slurm=config.get("slurm", {}),
            local_rank=config.get("local_rank", 0),
            logger=config.get("logger", None),
            print_every=config.get("print_every", 1),
            is_debug=config.get("is_debug", True),
            cpu=True,
        )

        if checkpoint is not None:
            try:
                self.load_checkpoint(checkpoint)
            except NotImplementedError:
                logging.warning("Unable to load checkpoint!")

        self.a2g = AtomsToGraphs(
            max_neigh=max_neighbors,
            radius=cutoff,
            r_energy=False,
            r_forces=False,
            r_distances=False,
        )

    def get_atoms_prediction(self, atoms):
        data_object = self.a2g.convert(atoms)
        batch = data_list_collater([data_object])
        predictions = self.predict(data_loader=batch,
                                   per_image=False,
                                   results_file=None,
                                   disable_tqdm=True)
        energy = predictions["energy"].item()
        forces = predictions["forces"].cpu().numpy()
        return energy, forces

    def train(self, disable_eval_tqdm=False):
        eval_every = self.config["optim"].get("eval_every", None)
        if eval_every is None:
            eval_every = len(self.train_loader)
        checkpoint_every = self.config["optim"].get("checkpoint_every",
                                                    eval_every)
        primary_metric = self.config["task"].get(
            "primary_metric", self.evaluator.task_primary_metric[self.name])
        self.best_val_metric = 1e9 if "mae" in primary_metric else -1.0
        self.metrics = {}

        # Calculate start_epoch from step instead of loading the epoch number
        # to prevent inconsistencies due to different batch size in checkpoint.
        start_epoch = self.step // len(self.train_loader)

        for epoch_int in range(start_epoch,
                               self.config["optim"]["max_epochs"]):
            self.train_sampler.set_epoch(epoch_int)
            skip_steps = self.step % len(self.train_loader)
            train_loader_iter = iter(self.train_loader)

            for i in range(skip_steps, len(self.train_loader)):
                self.epoch = epoch_int + (i + 1) / len(self.train_loader)
                self.step = epoch_int * len(self.train_loader) + i + 1
                self.model.train()

                # Get a batch.
                batch = next(train_loader_iter)

                if self.config["optim"]["optimizer"] == "LBFGS":

                    def closure():
                        self.optimizer.zero_grad()
                        with torch.cuda.amp.autocast(
                                enabled=self.scaler is not None):
                            out = self._forward(batch)
                            loss = self._compute_loss(out, batch)
                        loss.backward()
                        return loss

                    self.optimizer.step(closure)
                    self.optimizer.zero_grad()

                    # Recompute outputs and loss for metric logging.
                    with torch.cuda.amp.autocast(
                            enabled=self.scaler is not None):
                        out = self._forward(batch)
                        loss = self._compute_loss(out, batch)
                else:
                    # Forward, loss, backward.
                    with torch.cuda.amp.autocast(
                            enabled=self.scaler is not None):
                        out = self._forward(batch)
                        loss = self._compute_loss(out, batch)
                    loss = self.scaler.scale(loss) if self.scaler else loss
                    self._backward(loss)

                scale = self.scaler.get_scale() if self.scaler else 1.0

                # Compute metrics.
                self.metrics = self._compute_metrics(
                    out,
                    batch,
                    self.evaluator,
                    self.metrics,
                )
                self.metrics = self.evaluator.update(
                    "loss", loss.item() / scale, self.metrics)

                # Log metrics.
                log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
                log_dict.update({
                    "lr": self.scheduler.get_lr(),
                    "epoch": self.epoch,
                    "step": self.step,
                })
                if (self.step % self.config["cmd"]["print_every"] == 0
                        and distutils.is_master() and not self.is_hpo):
                    log_str = [
                        "{}: {:.2e}".format(k, v) for k, v in log_dict.items()
                    ]
                    logging.info(", ".join(log_str))
                    self.metrics = {}
                if self.logger is not None:
                    self.logger.log(
                        log_dict,
                        step=self.step,
                        split="train",
                    )

                if (checkpoint_every != -1
                        and self.step % checkpoint_every == 0):
                    self.save(checkpoint_file="checkpoint.pt",
                              training_state=True)

                # Evaluate on val set every `eval_every` iterations.
                if self.step % eval_every == 0:
                    if self.val_loader is not None:
                        val_metrics = self.validate(
                            split="val",
                            disable_tqdm=disable_eval_tqdm,
                        )
                        self.update_best(
                            primary_metric,
                            val_metrics,
                            disable_eval_tqdm=disable_eval_tqdm,
                        )
                        if self.is_hpo:
                            self.hpo_update(
                                self.epoch,
                                self.step,
                                self.metrics,
                                val_metrics,
                            )

                    if self.config["task"].get("eval_relaxations", False):
                        if "relax_dataset" not in self.config["task"]:
                            logging.warning(
                                "Cannot evaluate relaxations, "
                                "relax_dataset not specified")
                        else:
                            self.run_relaxations()

                if self.config["optim"].get("print_loss_and_lr", False):
                    if (self.step % eval_every == 0
                            and self.val_loader is not None):
                        print("epoch: " + str(self.epoch) +
                              ", \tstep: " + str(self.step) +
                              ", \tloss: " + str(loss.detach().item()) +
                              ", \tlr: " + str(self.scheduler.get_lr()) +
                              ", \tval: " + str(val_metrics["loss"]["total"]))
                    else:
                        print("epoch: " + str(self.epoch) +
                              ", \tstep: " + str(self.step) +
                              ", \tloss: " + str(loss.detach().item()) +
                              ", \tlr: " + str(self.scheduler.get_lr()))

                if self.scheduler.scheduler_type == "ReduceLROnPlateau":
                    if (self.step % eval_every == 0
                            and self.config["optim"].get(
                                "scheduler_loss", None) == "train"):
                        self.scheduler.step(metrics=loss.detach().item())
                    elif (self.step % eval_every == 0
                          and self.val_loader is not None):
                        self.scheduler.step(
                            metrics=val_metrics[primary_metric]["metric"])
                else:
                    self.scheduler.step()

                break_below_lr = (self.config["optim"].get(
                    "break_below_lr", None) is not None) and (
                        self.scheduler.get_lr() <
                        self.config["optim"]["break_below_lr"])
                if break_below_lr:
                    break
            if break_below_lr:
                break

            torch.cuda.empty_cache()

        if checkpoint_every == -1:
            self.save(checkpoint_file="checkpoint.pt", training_state=True)

        self.train_dataset.close_db()
        if "val_dataset" in self.config:
            self.val_dataset.close_db()
        if "test_dataset" in self.config:
            self.test_dataset.close_db()
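A minimal sketch of restoring this Trainer straight from a checkpoint and querying a single structure; the checkpoint path is a placeholder, and `get_atoms_prediction` is the helper defined above.

from ase.build import fcc111

trainer = Trainer(checkpoint="checkpoints/example_checkpoint.pt")  # hypothetical path
slab = fcc111("Cu", size=(2, 2, 3), vacuum=8.0)
energy, forces = trainer.get_atoms_prediction(slab)
print(energy, forces.shape)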