Example #1
    def __init__(self,
                 config_yml=None,
                 checkpoint=None,
                 cutoff=6,
                 max_neighbors=50):
        setup_imports()
        setup_logging()

        # Either the config path or the checkpoint path needs to be provided
        assert config_yml is not None or checkpoint is not None

        if config_yml is not None:
            if isinstance(config_yml, str):
                with open(config_yml, "r") as f:
                    config = yaml.safe_load(f)

                if "includes" in config:
                    for include in config["includes"]:
                        # Resolve the include path relative to config_yml
                        path = os.path.join(
                            config_yml.split("configs")[0], include)
                        with open(path, "r") as inc_f:
                            include_config = yaml.safe_load(inc_f)
                        config.update(include_config)
            else:
                config = config_yml
            # Keep only the train dataset entry, which may carry normalizer values
            config["dataset"] = config["dataset"][0]
        else:
            # Loads the config from the checkpoint directly
            config = torch.load(checkpoint,
                                map_location=torch.device("cpu"))["config"]

            # Select the trainer type based on the dataset used
            if config["task"]["dataset"] == "trajectory_lmdb":
                config["trainer"] = "forces"
            else:
                config["trainer"] = "energy"

            config["model_attributes"]["name"] = config.pop("model")
            config["model"] = config["model_attributes"]

        # Calculate the edge indices on the fly
        config["model"]["otf_graph"] = True

        # Save config so obj can be transported over network (pkl)
        self.config = copy.deepcopy(config)
        self.config["checkpoint"] = checkpoint

        if "normalizer" not in config:
            del config["dataset"]["src"]
            config["normalizer"] = config["dataset"]

        super().__init__(
            task=config["task"],
            model=config["model"],
            dataset=None,
            optimizer=config["optim"],
            identifier="",
            normalizer=config["normalizer"],
            slurm=config.get("slurm", {}),
            local_rank=config.get("local_rank", 0),
            logger=config.get("logger", None),
            print_every=config.get("print_every", 1),
            is_debug=config.get("is_debug", True),
            cpu=True,
        )

        if checkpoint is not None:
            try:
                self.load_checkpoint(checkpoint)
            except NotImplementedError:
                logging.warning("Unable to load checkpoint!")

        self.a2g = AtomsToGraphs(
            max_neigh=max_neighbors,
            radius=cutoff,
            r_energy=False,
            r_forces=False,
            r_distances=False,
        )
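
A minimal usage sketch of the constructor above (illustration only; the enclosing class is assumed to be called Trainer, as in a later example, and both paths are hypothetical placeholders): it can be built either from a YAML config or directly from a checkpoint.

# Hypothetical instantiation; the paths are placeholders.
trainer_from_yaml = Trainer(config_yml="configs/s2ef/example.yml")
trainer_from_ckpt = Trainer(checkpoint="checkpoints/s2ef_model.pt",
                            cutoff=6, max_neighbors=50)
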
Example #2
        default=1,
        help="No. of feature-extracting processes or no. of dataset chunks",
    )
    parser.add_argument("--ref-energy",
                        action="store_true",
                        help="Subtract reference energies")
    args = parser.parse_args()

    xyz_logs = glob.glob(os.path.join(args.data_path, "*.txt"))

    # Initialize feature extractor.
    a2g = AtomsToGraphs(
        max_neigh=50,
        radius=6,
        r_energy=True,
        r_forces=True,
        r_distances=False,
        r_fixed=True,
        r_edges=args.get_edges,
    )

    # Create output directory if it doesn't exist.
    os.makedirs(args.out_path, exist_ok=True)

    # Initialize lmdb paths
    db_paths = [
        os.path.join(args.out_path, "data.%04d.lmdb" % i)
        for i in range(args.num_workers)
    ]

    # Chunk the trajectories into args.num_workers splits
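
A plausible sketch of the chunking step announced by the final comment, under the assumption that numpy is imported and that each chunk is later paired with one of the LMDB paths above; none of this is taken from the original script.

import numpy as np

# Split the xyz logs into one chunk per worker and pair each chunk
# with the LMDB file it should be written to (hypothetical layout).
chunked_xyz_logs = np.array_split(xyz_logs, args.num_workers)
pool_args = [(db_paths[i], list(chunk))
             for i, chunk in enumerate(chunked_xyz_logs)]
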
Example #3
class Trainer(ForcesTrainer):
    def __init__(self,
                 config_yml=None,
                 checkpoint=None,
                 cutoff=6,
                 max_neighbors=50):
        setup_imports()
        setup_logging()

        # Either the config path or the checkpoint path needs to be provided
        assert config_yml is not None or checkpoint is not None

        if config_yml is not None:
            if isinstance(config_yml, str):
                with open(config_yml, "r") as f:
                    config = yaml.safe_load(f)

                if "includes" in config:
                    for include in config["includes"]:
                        # Resolve the include path relative to config_yml
                        path = os.path.join(
                            config_yml.split("configs")[0], include)
                        with open(path, "r") as inc_f:
                            include_config = yaml.safe_load(inc_f)
                        config.update(include_config)
            else:
                config = config_yml
            # Keep only the train dataset entry, which may carry normalizer values
            config["dataset"] = config["dataset"][0]
        else:
            # Loads the config from the checkpoint directly
            config = torch.load(checkpoint,
                                map_location=torch.device("cpu"))["config"]

            # Select the trainer type based on the dataset used
            if config["task"]["dataset"] == "trajectory_lmdb":
                config["trainer"] = "forces"
            else:
                config["trainer"] = "energy"

            config["model_attributes"]["name"] = config.pop("model")
            config["model"] = config["model_attributes"]

        # Calculate the edge indices on the fly
        config["model"]["otf_graph"] = True

        # Save config so obj can be transported over network (pkl)
        self.config = copy.deepcopy(config)
        self.config["checkpoint"] = checkpoint

        if "normalizer" not in config:
            del config["dataset"]["src"]
            config["normalizer"] = config["dataset"]

        super().__init__(
            task=config["task"],
            model=config["model"],
            dataset=None,
            optimizer=config["optim"],
            identifier="",
            normalizer=config["normalizer"],
            slurm=config.get("slurm", {}),
            local_rank=config.get("local_rank", 0),
            logger=config.get("logger", None),
            print_every=config.get("print_every", 1),
            is_debug=config.get("is_debug", True),
            cpu=True,
        )

        if checkpoint is not None:
            try:
                self.load_checkpoint(checkpoint)
            except NotImplementedError:
                logging.warning("Unable to load checkpoint!")

        self.a2g = AtomsToGraphs(
            max_neigh=max_neighbors,
            radius=cutoff,
            r_energy=False,
            r_forces=False,
            r_distances=False,
        )

    def get_atoms_prediction(self, atoms):
        data_object = self.a2g.convert(atoms)
        batch = data_list_collater([data_object])
        predictions = self.predict(data_loader=batch,
                                   per_image=False,
                                   results_file=None,
                                   disable_tqdm=True)
        energy = predictions["energy"].item()
        forces = predictions["forces"].cpu().numpy()
        return energy, forces

    def train(self, disable_eval_tqdm=False):
        eval_every = self.config["optim"].get("eval_every", None)
        if eval_every is None:
            eval_every = len(self.train_loader)
        checkpoint_every = self.config["optim"].get("checkpoint_every",
                                                    eval_every)
        primary_metric = self.config["task"].get(
            "primary_metric", self.evaluator.task_primary_metric[self.name])
        self.best_val_metric = 1e9 if "mae" in primary_metric else -1.0
        self.metrics = {}

        # Calculate start_epoch from step instead of loading the epoch number
        # to prevent inconsistencies due to different batch size in checkpoint.
        start_epoch = self.step // len(self.train_loader)

        for epoch_int in range(start_epoch,
                               self.config["optim"]["max_epochs"]):
            self.train_sampler.set_epoch(epoch_int)
            skip_steps = self.step % len(self.train_loader)
            train_loader_iter = iter(self.train_loader)

            for i in range(skip_steps, len(self.train_loader)):
                self.epoch = epoch_int + (i + 1) / len(self.train_loader)
                self.step = epoch_int * len(self.train_loader) + i + 1
                self.model.train()

                # Get a batch.
                batch = next(train_loader_iter)

                if self.config["optim"]["optimizer"] == "LBFGS":

                    def closure():
                        self.optimizer.zero_grad()
                        with torch.cuda.amp.autocast(
                                enabled=self.scaler is not None):
                            out = self._forward(batch)
                            loss = self._compute_loss(out, batch)
                        loss.backward()
                        return loss

                    self.optimizer.step(closure)

                    # Recompute outputs and loss (without a backward pass) so
                    # the metric computation below reflects the updated weights.
                    self.optimizer.zero_grad()
                    with torch.cuda.amp.autocast(
                            enabled=self.scaler is not None):
                        out = self._forward(batch)
                        loss = self._compute_loss(out, batch)

                else:
                    # Forward, loss, backward.
                    with torch.cuda.amp.autocast(
                            enabled=self.scaler is not None):
                        out = self._forward(batch)
                        loss = self._compute_loss(out, batch)
                    loss = self.scaler.scale(loss) if self.scaler else loss
                    self._backward(loss)

                scale = self.scaler.get_scale() if self.scaler else 1.0

                # Compute metrics.
                self.metrics = self._compute_metrics(
                    out,
                    batch,
                    self.evaluator,
                    self.metrics,
                )
                self.metrics = self.evaluator.update("loss",
                                                     loss.item() / scale,
                                                     self.metrics)

                # Log metrics.
                log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
                log_dict.update({
                    "lr": self.scheduler.get_lr(),
                    "epoch": self.epoch,
                    "step": self.step,
                })
                if (self.step % self.config["cmd"]["print_every"] == 0
                        and distutils.is_master() and not self.is_hpo):
                    log_str = [
                        "{}: {:.2e}".format(k, v) for k, v in log_dict.items()
                    ]
                    logging.info(", ".join(log_str))
                    self.metrics = {}

                if self.logger is not None:
                    self.logger.log(
                        log_dict,
                        step=self.step,
                        split="train",
                    )

                if checkpoint_every != -1 and self.step % checkpoint_every == 0:
                    self.save(checkpoint_file="checkpoint.pt",
                              training_state=True)

                # Evaluate on val set every `eval_every` iterations.
                if self.step % eval_every == 0:
                    if self.val_loader is not None:
                        val_metrics = self.validate(
                            split="val",
                            disable_tqdm=disable_eval_tqdm,
                        )
                        self.update_best(
                            primary_metric,
                            val_metrics,
                            disable_eval_tqdm=disable_eval_tqdm,
                        )
                        if self.is_hpo:
                            self.hpo_update(
                                self.epoch,
                                self.step,
                                self.metrics,
                                val_metrics,
                            )

                    if self.config["task"].get("eval_relaxations", False):
                        if "relax_dataset" not in self.config["task"]:
                            logging.warning(
                                "Cannot evaluate relaxations, relax_dataset not specified"
                            )
                        else:
                            self.run_relaxations()

                if self.config["optim"].get("print_loss_and_lr", False):
                    print(
                        "epoch: " + str(self.epoch) + ", \tstep: " +
                        str(self.step) + ", \tloss: " +
                        str(loss.detach().item()) + ", \tlr: " +
                        str(self.scheduler.get_lr()) + ", \tval: " +
                        str(val_metrics["loss"]["total"])
                    ) if self.step % eval_every == 0 and self.val_loader is not None else print(
                        "epoch: " + str(self.epoch) + ", \tstep: " +
                        str(self.step) + ", \tloss: " +
                        str(loss.detach().item()) + ", \tlr: " +
                        str(self.scheduler.get_lr()))

                if self.scheduler.scheduler_type == "ReduceLROnPlateau":
                    if (self.step % eval_every == 0
                            and self.config["optim"].get(
                                "scheduler_loss", None) == "train"):
                        self.scheduler.step(metrics=loss.detach().item())
                    elif self.step % eval_every == 0 and self.val_loader is not None:
                        self.scheduler.step(
                            metrics=val_metrics[primary_metric]["metric"])
                else:
                    self.scheduler.step()

                break_below_lr = (self.config["optim"].get(
                    "break_below_lr", None) is not None) and (
                        self.scheduler.get_lr() <
                        self.config["optim"]["break_below_lr"])
                if break_below_lr:
                    break
            if break_below_lr:
                break

            torch.cuda.empty_cache()

            if checkpoint_every == -1:
                self.save(checkpoint_file="checkpoint.pt", training_state=True)

        self.train_dataset.close_db()
        if "val_dataset" in self.config:
            self.val_dataset.close_db()
        if "test_dataset" in self.config:
            self.test_dataset.close_db()
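
A hedged end-to-end sketch of how this Trainer might be used for a single-point prediction; the ASE structure and the checkpoint path are illustrative placeholders, and only get_atoms_prediction comes from the class above.

from ase.build import fcc111, add_adsorbate

# Toy adsorbate/slab system (illustrative only).
slab = fcc111("Cu", size=(2, 2, 3), vacuum=10.0)
add_adsorbate(slab, "O", height=1.2, position="fcc")

trainer = Trainer(checkpoint="checkpoints/s2ef_model.pt")  # hypothetical path
energy, forces = trainer.get_atoms_prediction(slab)
print(energy, forces.shape)
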
Example #4
        pdb.set_trace()

    # Read the adslab reference energies to subtract from potential energies.
    with open(args.adslab_ref, "rb") as f:
        adslab_ref = pickle.load(f)

    # Read tag information for each atom.
    with open(args.tags, "rb") as f:
        sysid_to_tags = pickle.load(f)

    # Initialize feature extractor.
    a2g = AtomsToGraphs(
        max_neigh=50,
        radius=6,
        r_energy=True,
        r_forces=True,
        r_distances=True,
        r_fixed=True,
    )

    # Create output directory if it doesn't exist.
    os.makedirs(args.out_path, exist_ok=True)

    # Initialize lmdb paths
    db_path = os.path.join(args.out_path, "data.lmdb")
    db = lmdb.open(
        db_path,
        map_size=1099511627776 * 2,  # ~2 TiB
        subdir=False,
        meminit=False,
        map_async=True,
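
The snippet is cut off inside the lmdb.open(...) call. For context, a typical write loop over extracted samples in scripts like this one looks roughly like the sketch below; it relies only on the standard lmdb Python API, and data_objects is a placeholder name, so treat it as an assumption rather than the original code.

# Sketch: pickle each extracted sample and store it under an integer key.
for idx, data in enumerate(data_objects):  # `data_objects` is hypothetical
    txn = db.begin(write=True)
    txn.put(f"{idx}".encode("ascii"), pickle.dumps(data, protocol=-1))
    txn.commit()

db.sync()
db.close()
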
Example #5
    def __init__(
        self,
        model_path,
        checkpoint_path,
        dataset=None,
        a2g=None,
        task=None,
        identifier="active_learner_base_calc",
        seed=0,
        **kwargs,
    ):
        Calculator.__init__(self, **kwargs)

        self.model_path = model_path
        self.checkpoint_path = checkpoint_path
        self.dataset = dataset
        self.a2g = a2g
        self.task = task
        self.identifier = identifier
        self.kwargs = kwargs
        self.seed = seed

        self.model_dict = {}
        with open(model_path) as model_yaml:
            self.model_dict = yaml.safe_load(model_yaml)
        self.model_dict["optim"]["num_workers"] = 4
        # model_dict["model"]["freeze"] = False

        if not task:
            task = {
                "dataset": "trajectory_lmdb",  # dataset used for the S2EF task
                "description": "S2EF for active learning base calc",
                "type": "regression",
                "metric": "mae",
                "labels": ["potential energy"],
                "grad_input": "atomic forces",
                "train_on_free_atoms": True,
                "eval_on_free_atoms": True,
            }

        if not dataset:
            dataset = [{
                "src": "/home/jovyan/shared-datasets/OC20/s2ef/30k/train",
                "normalize_labels": False,
            }]

        if not a2g:
            a2g = AtomsToGraphs(
                max_neigh=50,
                radius=6,
                r_energy=True,
                r_forces=True,
                r_distances=False,
                r_edges=True,
                r_fixed=True,
            )
        self.a2g = a2g
        self.a2g_predict = copy.deepcopy(self.a2g)
        self.a2g_predict.r_forces = False
        self.a2g_predict.r_energy = False

        self.trainer = ForcesTrainer(
            task=task,
            model=self.model_dict["model"],
            dataset=dataset,
            optimizer=self.model_dict["optim"],
            identifier=identifier,
            is_debug=True,
            is_vis=False,
            cpu=True,
        )

        self.trainer.load_pretrained(checkpoint_path=checkpoint_path,
                                     ddp_to_dp=True)
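
Because this __init__ calls Calculator.__init__, the enclosing class is presumably an ASE Calculator. A hedged usage sketch follows, assuming the class (hypothetically named ActiveLearnerBaseCalc) also implements calculate() so that ASE's energy/forces calls dispatch to the wrapped ForcesTrainer; the paths are placeholders.

from ase.build import molecule

calc = ActiveLearnerBaseCalc(            # hypothetical class name
    model_path="configs/s2ef/base.yml",  # placeholder paths
    checkpoint_path="checkpoints/s2ef_model.pt",
)

atoms = molecule("H2O")                  # any ASE Atoms object
atoms.calc = calc
energy = atoms.get_potential_energy()
forces = atoms.get_forces()
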
Example #6
def oc20_initialize(model_name, gpu=True):
    """
    Initialize GNNP of OC20 (i.e. S2EF).
    Args:
        model_name (str): name of the model for the GNNP. One of the following:
            - "DimeNet++"
            - "GemNet-dT"
            - "CGCNN"
            - "SchNet"
            - "SpinConv"
        gpu (bool): use the GPU if available.
    Returns:
        cutoff: cutoff radius.
    """

    setup_imports()
    setup_logging()

    # Check model_name
    log_file = open("log.oc20", "w")
    log_file.write("\n")
    log_file.write("model_name = " + model_name + "\n")

    if model_name is not None:
        model_name = model_name.lower()

    if model_name == "DimeNet++".lower():
        config_yml = "dimenetpp.yml"
        checkpoint = "dimenetpp_all.pt"

    elif model_name == "GemNet-dT".lower():
        config_yml = "gemnet.yml"
        checkpoint = "gemnet_t_direct_h512_all.pt"

    elif model_name == "CGCNN".lower():
        config_yml = "cgcnn.yml"
        checkpoint = "cgcnn_all.pt"

    elif model_name == "SchNet".lower():
        config_yml = "schnet.yml"
        checkpoint = "schnet_all_large.pt"

    elif model_name == "SpinConv".lower():
        config_yml = "spinconv.yml"
        checkpoint = "spinconv_force_centric_all.pt"

    else:
        raise Exception("incorrect model_name.")

    basePath = os.path.dirname(os.path.abspath(__file__))
    config_dir = os.path.normpath(os.path.join(basePath, "oc20_configs"))
    checkpt_dir = os.path.normpath(os.path.join(basePath, "oc20_checkpt"))
    config_yml = os.path.normpath(os.path.join(config_dir, config_yml))
    checkpoint = os.path.normpath(os.path.join(checkpt_dir, checkpoint))

    log_file.write("config_yml = " + config_yml + "\n")
    log_file.write("checkpoint = " + checkpoint + "\n")

    # Check gpu
    gpu_ = (gpu and torch.cuda.is_available())

    log_file.write("gpu (in)   = " + str(gpu) + "\n")
    log_file.write("gpu (eff)  = " + str(gpu_) + "\n")

    # Load configuration
    with open(config_yml, "r") as f:
        config = yaml.safe_load(f)

    # Check max_neigh and cutoff
    max_neigh = config["model"].get("max_neighbors", 50)
    cutoff = config["model"].get("cutoff", 6.0)

    log_file.write("max_neigh  = " + str(max_neigh) + "\n")
    log_file.write("cutoff     = " + str(cutoff) + "\n")

    assert max_neigh > 0
    assert cutoff > 0.0

    # To calculate the edge indices on-the-fly
    config["model"]["otf_graph"] = True

    # Modify path of scale_file for GemNet-dT
    scale_file = config["model"].get("scale_file", None)

    if scale_file is not None:
        scale_file = os.path.normpath(os.path.join(config_dir, scale_file))
        config["model"]["scale_file"] = scale_file

    log_file.write("\nconfig:\n")
    log_file.write(pprint.pformat(config) + "\n")
    log_file.write("\n")
    log_file.close()

    # Create the trainer; the pre-trained weights are loaded below
    global myTrainer

    myTrainer = registry.get_trainer_class(config.get("trainer", "forces"))(
        task=config["task"],
        model=config["model"],
        dataset=None,
        normalizer=config["normalizer"],
        optimizer=config["optim"],
        identifier="",
        slurm=config.get("slurm", {}),
        local_rank=config.get("local_rank", 0),
        is_debug=config.get("is_debug", True),
        cpu=not gpu_)

    # Load checkpoint
    myTrainer.load_checkpoint(checkpoint)

    # ASE Atoms object; not set yet
    global myAtoms

    myAtoms = None

    # Converter: Atoms -> Graphs (edges are computed on the fly)
    global myA2G

    myA2G = AtomsToGraphs(max_neigh=max_neigh,
                          radius=cutoff,
                          r_energy=False,
                          r_forces=False,
                          r_distances=False,
                          r_edges=False,
                          r_fixed=False)

    return cutoff
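
A short usage sketch based on the docstring above: model_name must be one of the listed options, and the returned value is the cutoff radius used by the converter.

cutoff = oc20_initialize("GemNet-dT", gpu=False)
print("cutoff radius:", cutoff)
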
Example #7
    with open(args.traj_paths_txt, "r") as f:
        raw_traj_files = f.read().splitlines()
    num_trajectories = len(raw_traj_files)

    with open(args.adslab_ref, "rb") as g:
        adslab_ref = pickle.load(g)

    print("### Found %d trajectories in %s" %
          (num_trajectories, args.traj_paths_txt))

    # Initialize feature extractor.
    a2g = AtomsToGraphs(
        max_neigh=12,
        radius=6,
        dummy_distance=7,
        dummy_index=-1,
        r_energy=True,
        r_forces=True,
        r_distances=False,
        r_fixed=True,
    )

    # Create output directory if it doesn't exist.
    os.makedirs(args.out_path, exist_ok=True)

    # Initialize lmdb paths
    db_paths = [
        os.path.join(args.out_path, "data.%03d.lmdb" % i)
        for i in range(args.num_workers)
    ]

    # Chunk the trajectories into args.num_workers splits
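
A plausible sketch of the chunking and dispatch announced by the final comment, assuming multiprocessing and numpy are imported and that a worker function (hypothetically named write_images_to_lmdb) consumes each (db_path, trajectory_chunk) pair; none of this is taken from the original script.

import multiprocessing as mp
import numpy as np

# Split trajectories across workers and hand each chunk, together with its
# target LMDB path, to a (hypothetical) worker function.
chunks = np.array_split(raw_traj_files, args.num_workers)
pool_args = [(db_paths[i], list(chunk)) for i, chunk in enumerate(chunks)]

with mp.Pool(args.num_workers) as pool:
    pool.starmap(write_images_to_lmdb, pool_args)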