Example #1
0
    def __init__(
        self,
        model_path: str,
        checkpoint_path: str,
        mlp_params: "dict | None" = None,
    ):
        """Set up the calculator around an ``OCPDescriptor``.

        Args:
            model_path: Model config path, forwarded to ``OCPDescriptor``.
            checkpoint_path: Checkpoint path, forwarded to ``OCPDescriptor``.
            mlp_params: Parameters handed to ``MLPCalc.__init__``. ``None``
                (the default) is treated as an empty dict.
        """
        # Use a fresh dict per call: a shared ``{}`` default would carry any
        # mutation made through the stored params (if MLPCalc keeps a
        # reference to it) into every later instance.
        if mlp_params is None:
            mlp_params = {}
        MLPCalc.__init__(self, mlp_params=mlp_params)

        self.ocp_describer = OCPDescriptor(
            model_path=model_path,
            checkpoint_path=checkpoint_path,
        )

        self.init_model()
Example #2
0
 def __init__(self,
              flare_params: dict,
              initial_images,
              mgp_model=None,
              par=False,
              use_mapping=False,
              **kwargs):
     """Initialize the calculator from ``flare_params`` and ``initial_images``.

     Args:
         flare_params: Passed to ``MLPCalc.__init__`` as ``mlp_params``.
         initial_images: Stored on ``self.initial_images`` before
             ``init_species_map()`` is called.
         mgp_model: Forwarded to the next ``__init__`` in the MRO.
         par: Forwarded to the next ``__init__`` in the MRO.
         use_mapping: Forwarded to the next ``__init__`` in the MRO.
         **kwargs: Forwarded unchanged to the next ``__init__`` in the MRO.
     """
     # Set before init_species_map(), which presumably reads it — TODO confirm.
     self.initial_images = initial_images
     self.init_species_map()
     MLPCalc.__init__(self, mlp_params=flare_params)
     # NOTE(review): the positional None presumably fills the parent
     # calculator's model slot (no model at construction time) — confirm.
     super().__init__(None,
                      mgp_model=mgp_model,
                      par=par,
                      use_mapping=use_mapping,
                      **kwargs)
Example #3
0
 def __init__(self, mlp_params, initial_images):
     """Initialize model state, the species map, and optional settings.

     Args:
         mlp_params: Forwarded to ``MLPCalc.__init__``. The optional keys
             listed in ``option_defaults`` below are then read back from
             ``self.mlp_params``, falling back to the paired default.
         initial_images: Stored on the instance before
             ``init_species_map()`` runs.
     """
     MLPCalc.__init__(self, mlp_params=mlp_params)

     # Start with no GP model, no mapped GP model and no cached results.
     self.mgp_model = None
     self.gp_model = None
     self.use_mapping = False
     self.results = {}

     self.initial_images = initial_images
     self.init_species_map()

     # (attribute name == mlp_params key, fallback when the key is absent)
     option_defaults = (
         ("update_gp_mode", "all"),
         ("update_gp_range", []),
         ("freeze_hyps", None),
         ("variance_type", "SOR"),
         ("opt_method", "BFGS"),
         ("kernel_type", "NormalizedDotProduct"),
     )
     for option_name, fallback in option_defaults:
         setattr(self, option_name, self.mlp_params.get(option_name, fallback))

     self.iteration = 0
    def __init__(
        self,
        model_classes: "list[str]",
        model_paths: "list[str]",
        checkpoint_paths: "list[str]",
        mlp_params: "dict | list[dict] | None" = None,
    ) -> None:
        """Build an ensemble with one ``FinetunerCalc`` per model.

        Args:
            model_classes: Model name for each member, passed to
                ``FinetunerCalc`` as ``model_name``.
            model_paths: Model config path for each member.
            checkpoint_paths: Checkpoint path for each member.
            mlp_params: Either a single params dict shared by all members,
                or a list holding one params dict per member. Every member
                receives its own deep copy, so the caller's objects are never
                modified. The ensemble's own params (including
                ``tuner.ensemble_method``, default ``"mean"``) come from the
                shared dict or from the first list entry. ``None`` (the
                default) is treated as ``{}``.

        Raises:
            ValueError: If ``model_classes``, ``model_paths`` and
                ``checkpoint_paths`` differ in length, or if ``mlp_params``
                is a list that is empty or whose length differs from
                ``model_classes``.
        """
        if mlp_params is None:
            mlp_params = {}

        n_models = len(model_classes)
        if len(model_paths) != n_models or len(checkpoint_paths) != n_models:
            raise ValueError(
                "model_classes, model_paths and checkpoint_paths must have "
                f"the same length (got {n_models}, {len(model_paths)}, "
                f"{len(checkpoint_paths)})")
        per_model_params = isinstance(mlp_params, list)
        if per_model_params and (not mlp_params or len(mlp_params) != n_models):
            raise ValueError(
                "when mlp_params is a list it must be non-empty and hold one "
                f"dict per model (got {len(mlp_params)} for {n_models} models)")

        def params_for(index):
            # Deep copy so members never share (or mutate) the caller's dicts.
            chosen = mlp_params[index] if per_model_params else mlp_params
            return copy.deepcopy(chosen)

        self.model_classes = model_classes
        self.model_paths = model_paths
        self.checkpoint_paths = checkpoint_paths

        self.finetuner_calcs = [
            FinetunerCalc(
                model_name=model_class,
                model_path=model_path,
                checkpoint_path=checkpoint_path,
                mlp_params=params_for(i),
            )
            for i, (model_class, model_path, checkpoint_path) in enumerate(
                zip(model_classes, model_paths, checkpoint_paths))
        ]

        self.train_counter = 0
        self.ml_model = False

        # Ensemble-level settings come from the shared dict / first entry.
        ensemble_params = params_for(0)
        ensemble_params.setdefault("tuner", {})
        self.ensemble_method = ensemble_params["tuner"].get(
            "ensemble_method", "mean")
        MLPCalc.__init__(self, mlp_params=ensemble_params)
Example #5
0
    def __init__(
        self,
        model_name: str,
        model_path: str,
        checkpoint_path: str,
        mlp_params: "dict | None" = None,
    ):
        """Load a YAML model config plus checkpoint and build a ``Trainer``.

        Args:
            model_name: One of ``"gemnet"``, ``"spinconv"``, ``"dimenetpp"``;
                selects the default ``unfreeze_blocks``.
            model_path: Path to the YAML config. Entries of its ``includes``
                list are resolved against the part of this path before the
                first occurrence of ``"configs"``.
            checkpoint_path: Checkpoint passed to ``Trainer``; also inspected
                with ``torch.load`` when ``mlp_params["optim"]`` contains an
                ``"optimizer"`` entry.
            mlp_params: Overrides merged over the loaded config with
                ``merge_dict``. Its ``"tuner"`` sub-dict holds calculator
                options: ``max_neighbors`` (default 50), ``cutoff`` (default
                6), ``energy_training`` (default False), ``num_threads``,
                ``validation_split`` (default None) and ``unfreeze_blocks``
                (list or str). A deep copy is used, so the caller's dict is
                never modified. ``None`` (the default) means ``{}``.

        Raises:
            ValueError: If ``model_name`` is not supported; if an optimizer
                config is supplied while the checkpoint still stores
                optimizer/scheduler/ema/amp state; or if ``unfreeze_blocks``
                is neither a list nor a str.
        """
        import contextlib

        if model_name not in ["gemnet", "spinconv", "dimenetpp"]:
            raise ValueError("Invalid model name provided")

        # Private copy: this method adds a "tuner" key, and the merged config
        # may share nested dicts with it that are modified further below.
        mlp_params = {} if mlp_params is None else copy.deepcopy(mlp_params)

        if "optimizer" in mlp_params.get("optim", {}):
            # Refuse checkpoints that still carry training state, so that the
            # user-supplied optimizer config is the one that gets loaded.
            checkpoint = torch.load(checkpoint_path,
                                    map_location=torch.device("cpu"))
            for key in ["optimizer", "scheduler", "ema", "amp"]:
                if key in checkpoint and checkpoint[key] is not None:
                    raise ValueError(
                        f"{checkpoint_path}\n^this checkpoint contains {key}"
                        " information, please load the .pt file, delete the "
                        f"{key} dictionary, save it again as a .pt file, and "
                        "try again so that the given optimizer config will "
                        "be loaded")

        self.model_name = model_name
        self.model_path = model_path
        self.checkpoint_path = checkpoint_path

        mlp_params.setdefault("tuner", {})

        with open(self.model_path, "r") as config_file:
            config = yaml.safe_load(config_file)
        if "includes" in config:
            # Include paths are relative to the prefix of model_path that
            # precedes its first "configs".
            config_root = self.model_path.split("configs")[0]
            for include in config["includes"]:
                include_path = os.path.join(config_root, include)
                with open(include_path, "r") as include_file:
                    # NOTE(review): this shallow update lets each include
                    # override top-level keys of model_path's own config —
                    # confirm that precedence is intended.
                    config.update(yaml.safe_load(include_file))
        if "optimizer" in mlp_params.get("optim", {}):
            # Drop the file's optim section so the one from mlp_params is
            # used on its own rather than (presumably) merged into it.
            config.pop("optim", None)
        config = merge_dict(config, mlp_params)

        MLPCalc.__init__(self, mlp_params=config)

        self.train_counter = 0
        self.ml_model = False
        tuner = self.mlp_params["tuner"]
        self.max_neighbors = tuner.get("max_neighbors", 50)
        self.cutoff = tuner.get("cutoff", 6)
        self.energy_training = tuner.get("energy_training", False)
        if not self.energy_training:
            # energy_coefficient presumably weights the energy loss term;
            # create the optim section if the config has none.
            self.mlp_params.setdefault("optim", {})["energy_coefficient"] = 0
        if "num_threads" in tuner:
            torch.set_num_threads(tuner["num_threads"])
        self.validation_split = tuner.get("validation_split", None)

        # init block/weight freezing: per-architecture defaults, optionally
        # overridden by tuner["unfreeze_blocks"].
        default_unfreeze_blocks = {
            "gemnet": ["out_blocks.3"],
            "spinconv": ["force_output_block"],
            "dimenetpp": ["output_blocks.3"],
        }
        self.unfreeze_blocks = default_unfreeze_blocks[self.model_name]
        if "unfreeze_blocks" in tuner:
            requested_blocks = tuner["unfreeze_blocks"]
            if isinstance(requested_blocks, list):
                self.unfreeze_blocks = requested_blocks
            elif isinstance(requested_blocks, str):
                self.unfreeze_blocks = [requested_blocks]
            else:
                raise ValueError("invalid unfreeze_blocks parameter given")

        # init trainer
        config_dict = copy.deepcopy(self.mlp_params)

        # Silence console output produced while building the Trainer.
        # redirect_stdout restores the previously active stream even if
        # Trainer raises. The devnull handle is deliberately left open:
        # objects created during Trainer construction may keep a reference.
        devnull = open(os.devnull, "w")
        with contextlib.redirect_stdout(devnull):
            self.trainer = Trainer(
                config_yml=config_dict,
                checkpoint=self.checkpoint_path,
                cutoff=self.cutoff,
                max_neighbors=self.max_neighbors,
            )
 def __init__(self, amptorch_trainer, n_ensembles):
     """Initialize from an existing trainer object.

     Args:
         amptorch_trainer: Trainer whose ``config`` attribute is passed to
             ``MLPCalc.__init__`` as ``mlp_params``; kept on the instance.
         n_ensembles: Stored unchanged on the instance.
     """
     trainer_config = amptorch_trainer.config
     MLPCalc.__init__(self, mlp_params=trainer_config)
     self.n_ensembles = n_ensembles
     self.amptorch_trainer = amptorch_trainer