Example #1
    def get_data_from_atoms(self, dataset):
        """
        get train_loader object to replace for the ocp model trainer to train on
        """
        a2g = AtomsToGraphs(
            max_neigh=self.max_neighbors,
            radius=self.cutoff,
            r_energy=True,
            r_forces=True,
            r_distances=True,
            r_edges=False,
        )

        # Convert each ase.Atoms object in the dataset to a graph Data object.
        graphs_list = [a2g.convert(atoms) for atoms in dataset]

        # Assign placeholder frame/system ids expected downstream.
        for graph in graphs_list:
            graph.fid = 0
            graph.sid = 0

        graphs_list_dataset = GraphsListDataset(graphs_list)

        train_sampler = self.trainer.get_sampler(
            graphs_list_dataset,
            self.mlp_params.get("optim", {}).get("batch_size", 1),
            shuffle=False,
        )
        self.trainer.train_sampler = train_sampler

        data_loader = self.trainer.get_dataloader(
            graphs_list_dataset,
            train_sampler,
        )

        return data_loader
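
Note that GraphsListDataset only needs to expose graphs_list through the standard Dataset interface so the sampler and dataloader can index into it. A minimal sketch of such a wrapper, assuming the real class is defined elsewhere in the project and may differ:

from torch.utils.data import Dataset

class GraphsListDataset(Dataset):
    """Thin Dataset wrapper around an in-memory list of graph Data objects."""

    def __init__(self, graphs_list):
        self.graphs_list = graphs_list

    def __len__(self):
        return len(self.graphs_list)

    def __getitem__(self, idx):
        return self.graphs_list[idx]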
Example #2
    def process(self):
        print("### Preprocessing atoms objects from:  {}".format(
            self.raw_file_names[0]))
        traj = Trajectory(self.raw_file_names[0])
        a2g = AtomsToGraphs(
            max_neigh=self.config.get("max_neigh", 12),
            radius=self.config.get("radius", 6),
            dummy_distance=self.config.get("radius", 6) + 1,
            dummy_index=-1,
            r_energy=True,
            r_forces=True,
            r_distances=False,
        )

        data_list = []

        for atoms in tqdm(
                traj,
                desc="preprocessing atomic features",
                total=len(traj),
                unit="structure",
        ):
            data_list.append(a2g.convert(atoms))

        # Collate the list into a single Data object plus slice indices
        # (InMemoryDataset-style storage).
        self.data, self.slices = collate(data_list)
        torch.save((self.data, self.slices), self.processed_file_names[0])
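
Since process persists the collated tensors with torch.save, loading them back is symmetric. A minimal sketch, with "processed.pt" standing in for whatever self.processed_file_names[0] points to:

import torch

# Load the collated Data object and the slice dictionary written by process().
data, slices = torch.load("processed.pt")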
Example #3
class OCPCalculator(Calculator):
    implemented_properties = ["energy", "forces"]

    def __init__(self, trainer, pbc_graph=False):
        """
        OCP-ASE Calculator

        Args:
            trainer: Object
                ML trainer for energy and force predictions.
            pbc_graph: bool
                If True, recompute graph edges under periodic boundary
                conditions via radius_graph_pbc.
        """
        Calculator.__init__(self)
        self.trainer = trainer
        self.pbc_graph = pbc_graph
        self.a2g = AtomsToGraphs(
            max_neigh=50,
            radius=6,
            r_energy=False,
            r_forces=False,
            r_distances=False,
        )

    def train(self):
        self.trainer.train()

    def load_pretrained(self, checkpoint_path):
        """
        Load existing trained model

        Args:
            checkpoint_path: string
                Path to trained model
        """
        try:
            self.trainer.load_pretrained(checkpoint_path)
        except NotImplementedError:
            print("Unable to load checkpoint!")

    def calculate(self, atoms, properties, system_changes):
        Calculator.calculate(self, atoms, properties, system_changes)
        data_object = self.a2g.convert(atoms)
        batch = data_list_collater([data_object])
        if self.pbc_graph:
            # The radius (6) and max neighbors (50) here must match the
            # AtomsToGraphs settings used in __init__.
            edge_index, cell_offsets, neighbors = radius_graph_pbc(
                batch, 6, 50, batch.pos.device)
            batch.edge_index = edge_index
            batch.cell_offsets = cell_offsets
            batch.neighbors = neighbors

        predictions = self.trainer.predict(batch, per_image=False)
        if self.trainer.name == "s2ef":
            self.results["energy"] = predictions["energy"].item()
            self.results["forces"] = predictions["forces"].cpu().numpy()

        elif self.trainer.name == "is2re":
            self.results["energy"] = predictions["energy"].item()
Example #4
class OCPCalculator(Calculator):
    implemented_properties = ["energy", "forces"]

    def __init__(self, trainer):
        """
        OCP-ASE Calculator

        Args:
            trainer: Object
                ML trainer for energy and force predictions.
        """
        Calculator.__init__(self)
        self.trainer = trainer
        self.a2g = AtomsToGraphs(
            max_neigh=12,
            radius=6,
            dummy_distance=7,
            dummy_index=-1,
            r_energy=False,
            r_forces=False,
            r_distances=False,
        )

    def train(self):
        self.trainer.train()

    def load_pretrained(self, checkpoint_path):
        """
        Load existing trained model

        Args:
            checkpoint_path: string
                Path to trained model
        """
        try:
            self.trainer.load_pretrained(checkpoint_path)
        except NotImplementedError:
            print("Unable to load checkpoint!")

    def calculate(self, atoms, properties, system_changes):
        Calculator.calculate(self, atoms, properties, system_changes)
        data_object = self.a2g.convert(atoms)
        batch = Batch.from_data_list([data_object])
        predictions = self.trainer.predict(batch)

        self.results["energy"] = predictions["energy"][0]
        self.results["forces"] = predictions["forces"][0]
Example #5
File: ase_utils.py Project: wood-b/ocp
class OCPCalculator(Calculator):
    implemented_properties = ["energy", "forces"]

    def __init__(self,
                 config_yml,
                 checkpoint=None,
                 cutoff=6,
                 max_neighbors=50):
        """
        OCP-ASE Calculator

        Args:
            config_yml (str):
                Path to yaml config.
            checkpoint (str):
                Path to trained checkpoint.
            cutoff (int):
                Cutoff radius to be used for data preprocessing.
            max_neighbors (int):
                Maximum number of neighbors to store for a given atom.
        """
        setup_imports()
        setup_logging()
        Calculator.__init__(self)

        config = yaml.safe_load(open(config_yml, "r"))
        if "includes" in config:
            for include in config["includes"]:
                include_config = yaml.safe_load(open(include, "r"))
                config.update(include_config)

        # Save config so obj can be transported over network (pkl)
        self.config = copy.deepcopy(config)
        self.config["checkpoint"] = checkpoint

        self.trainer = registry.get_trainer_class(
            config.get("trainer", "simple"))(
                task=config["task"],
                model=config["model"],
                dataset=config["dataset"],
                optimizer=config["optim"],
                identifier="",
                slurm=config.get("slurm", {}),
                local_rank=config.get("local_rank", 0),
                is_debug=config.get("is_debug", True),
                cpu=True,
            )

        if checkpoint is not None:
            self.load_checkpoint(checkpoint)

        self.a2g = AtomsToGraphs(
            max_neigh=max_neighbors,
            radius=cutoff,
            r_energy=False,
            r_forces=False,
            r_distances=False,
        )

    def train(self):
        self.trainer.train()

    def load_checkpoint(self, checkpoint_path):
        """
        Load existing trained model

        Args:
            checkpoint_path: string
                Path to trained model
        """
        try:
            self.trainer.load_checkpoint(checkpoint_path)
        except NotImplementedError:
            logging.warning("Unable to load checkpoint!")

    def calculate(self, atoms, properties, system_changes):
        Calculator.calculate(self, atoms, properties, system_changes)
        data_object = self.a2g.convert(atoms)
        batch = data_list_collater([data_object])

        predictions = self.trainer.predict(batch,
                                           per_image=False,
                                           disable_tqdm=True)
        if self.trainer.name == "s2ef":
            self.results["energy"] = predictions["energy"].item()
            self.results["forces"] = predictions["forces"].cpu().numpy()

        elif self.trainer.name == "is2re":
            self.results["energy"] = predictions["energy"].item()
Example #6
class Trainer(ForcesTrainer):
    def __init__(self,
                 config_yml=None,
                 checkpoint=None,
                 cutoff=6,
                 max_neighbors=50):
        setup_imports()
        setup_logging()

        # Either the config path or the checkpoint path must be provided.
        assert config_yml is not None or checkpoint is not None

        if config_yml is not None:
            if isinstance(config_yml, str):
                config = yaml.safe_load(open(config_yml, "r"))

                if "includes" in config:
                    for include in config["includes"]:
                        # Change the path based on absolute path of config_yml
                        path = os.path.join(
                            config_yml.split("configs")[0], include)
                        include_config = yaml.safe_load(open(path, "r"))
                        config.update(include_config)
            else:
                config = config_yml
            # Only keeps the train data that might have normalizer values
            config["dataset"] = config["dataset"][0]
        else:
            # Loads the config from the checkpoint directly
            config = torch.load(checkpoint,
                                map_location=torch.device("cpu"))["config"]

            # Load the trainer based on the dataset used
            if config["task"]["dataset"] == "trajectory_lmdb":
                config["trainer"] = "forces"
            else:
                config["trainer"] = "energy"

            config["model_attributes"]["name"] = config.pop("model")
            config["model"] = config["model_attributes"]

        # Calculate the edge indices on the fly
        config["model"]["otf_graph"] = True

        # Save config so obj can be transported over network (pkl)
        self.config = copy.deepcopy(config)
        self.config["checkpoint"] = checkpoint

        if "normalizer" not in config:
            del config["dataset"]["src"]
            config["normalizer"] = config["dataset"]

        super().__init__(
            task=config["task"],
            model=config["model"],
            dataset=None,
            optimizer=config["optim"],
            identifier="",
            normalizer=config["normalizer"],
            slurm=config.get("slurm", {}),
            local_rank=config.get("local_rank", 0),
            logger=config.get("logger", None),
            print_every=config.get("print_every", 1),
            is_debug=config.get("is_debug", True),
            cpu=True,
        )

        if checkpoint is not None:
            try:
                self.load_checkpoint(checkpoint)
            except NotImplementedError:
                logging.warning("Unable to load checkpoint!")

        self.a2g = AtomsToGraphs(
            max_neigh=max_neighbors,
            radius=cutoff,
            r_energy=False,
            r_forces=False,
            r_distances=False,
        )

    def get_atoms_prediction(self, atoms):
        data_object = self.a2g.convert(atoms)
        # A single collated batch stands in for a full data_loader here.
        batch = data_list_collater([data_object])
        predictions = self.predict(data_loader=batch,
                                   per_image=False,
                                   results_file=None,
                                   disable_tqdm=True)
        energy = predictions["energy"].item()
        forces = predictions["forces"].cpu().numpy()
        return energy, forces

    def train(self, disable_eval_tqdm=False):
        eval_every = self.config["optim"].get("eval_every", None)
        if eval_every is None:
            eval_every = len(self.train_loader)
        checkpoint_every = self.config["optim"].get("checkpoint_every",
                                                    eval_every)
        primary_metric = self.config["task"].get(
            "primary_metric", self.evaluator.task_primary_metric[self.name])
        self.best_val_metric = 1e9 if "mae" in primary_metric else -1.0
        self.metrics = {}

        # Calculate start_epoch from step instead of loading the epoch number
        # to prevent inconsistencies due to different batch size in checkpoint.
        start_epoch = self.step // len(self.train_loader)

        # Initialized here so the post-loop check cannot hit an unbound
        # name if the inner loop body never runs.
        break_below_lr = False

        for epoch_int in range(start_epoch,
                               self.config["optim"]["max_epochs"]):
            self.train_sampler.set_epoch(epoch_int)
            skip_steps = self.step % len(self.train_loader)
            train_loader_iter = iter(self.train_loader)

            for i in range(skip_steps, len(self.train_loader)):
                self.epoch = epoch_int + (i + 1) / len(self.train_loader)
                self.step = epoch_int * len(self.train_loader) + i + 1
                self.model.train()

                # Get a batch.
                batch = next(train_loader_iter)

                if self.config["optim"]["optimizer"] == "LBFGS":

                    def closure():
                        self.optimizer.zero_grad()
                        with torch.cuda.amp.autocast(
                                enabled=self.scaler is not None):
                            out = self._forward(batch)
                            loss = self._compute_loss(out, batch)
                        loss.backward()
                        return loss

                    self.optimizer.step(closure)

                    # Recompute forward and loss so `out` and `loss` are
                    # defined for the metric computation below.
                    self.optimizer.zero_grad()
                    with torch.cuda.amp.autocast(
                            enabled=self.scaler is not None):
                        out = self._forward(batch)
                        loss = self._compute_loss(out, batch)

                else:
                    # Forward, loss, backward.
                    with torch.cuda.amp.autocast(
                            enabled=self.scaler is not None):
                        out = self._forward(batch)
                        loss = self._compute_loss(out, batch)
                    loss = self.scaler.scale(loss) if self.scaler else loss
                    self._backward(loss)

                scale = self.scaler.get_scale() if self.scaler else 1.0

                # Compute metrics.
                self.metrics = self._compute_metrics(
                    out,
                    batch,
                    self.evaluator,
                    self.metrics,
                )
                self.metrics = self.evaluator.update("loss",
                                                     loss.item() / scale,
                                                     self.metrics)

                # Log metrics.
                log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
                log_dict.update({
                    "lr": self.scheduler.get_lr(),
                    "epoch": self.epoch,
                    "step": self.step,
                })
                if (self.step % self.config["cmd"]["print_every"] == 0
                        and distutils.is_master() and not self.is_hpo):
                    log_str = [
                        "{}: {:.2e}".format(k, v) for k, v in log_dict.items()
                    ]
                    logging.info(", ".join(log_str))
                    self.metrics = {}

                if self.logger is not None:
                    self.logger.log(
                        log_dict,
                        step=self.step,
                        split="train",
                    )

                if checkpoint_every != -1 and self.step % checkpoint_every == 0:
                    self.save(checkpoint_file="checkpoint.pt",
                              training_state=True)

                # Evaluate on val set every `eval_every` iterations.
                if self.step % eval_every == 0:
                    if self.val_loader is not None:
                        val_metrics = self.validate(
                            split="val",
                            disable_tqdm=disable_eval_tqdm,
                        )
                        self.update_best(
                            primary_metric,
                            val_metrics,
                            disable_eval_tqdm=disable_eval_tqdm,
                        )
                        if self.is_hpo:
                            self.hpo_update(
                                self.epoch,
                                self.step,
                                self.metrics,
                                val_metrics,
                            )

                    if self.config["task"].get("eval_relaxations", False):
                        if "relax_dataset" not in self.config["task"]:
                            logging.warning(
                                "Cannot evaluate relaxations, relax_dataset not specified"
                            )
                        else:
                            self.run_relaxations()

                if self.config["optim"].get("print_loss_and_lr", False):
                    print(
                        "epoch: " + str(self.epoch) + ", \tstep: " +
                        str(self.step) + ", \tloss: " +
                        str(loss.detach().item()) + ", \tlr: " +
                        str(self.scheduler.get_lr()) + ", \tval: " +
                        str(val_metrics["loss"]["total"])
                    ) if self.step % eval_every == 0 and self.val_loader is not None else print(
                        "epoch: " + str(self.epoch) + ", \tstep: " +
                        str(self.step) + ", \tloss: " +
                        str(loss.detach().item()) + ", \tlr: " +
                        str(self.scheduler.get_lr()))

                if self.scheduler.scheduler_type == "ReduceLROnPlateau":
                    if (self.step % eval_every == 0
                            and self.config["optim"].get(
                                "scheduler_loss", None) == "train"):
                        self.scheduler.step(metrics=loss.detach().item())
                    elif self.step % eval_every == 0 and self.val_loader is not None:
                        self.scheduler.step(
                            metrics=val_metrics[primary_metric]["metric"])
                else:
                    self.scheduler.step()

                break_below_lr = (self.config["optim"].get(
                    "break_below_lr", None) is not None) and (
                        self.scheduler.get_lr() <
                        self.config["optim"]["break_below_lr"])
                if break_below_lr:
                    break
            if break_below_lr:
                break

            torch.cuda.empty_cache()

            if checkpoint_every == -1:
                self.save(checkpoint_file="checkpoint.pt", training_state=True)

        self.train_dataset.close_db()
        if "val_dataset" in self.config:
            self.val_dataset.close_db()
        if "test_dataset" in self.config:
            self.test_dataset.close_db()
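
The get_atoms_prediction helper makes this Trainer usable for one-off inference without wrapping it in a Calculator. A minimal sketch, again with placeholder paths:

from ase.build import bulk

# Placeholder paths; substitute a real config and checkpoint.
trainer = Trainer(
    config_yml="path/to/config.yml",
    checkpoint="path/to/checkpoint.pt",
)

atoms = bulk("Cu", "fcc", a=3.6)
energy, forces = trainer.get_atoms_prediction(atoms)
print(energy, forces.shape)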