def ocp_trainable(config, checkpoint_dir=None):
    setup_imports()
    # update config for PBT learning rate
    config["optim"].update(lr_initial=config["lr"])
    # trainer defaults are changed to run HPO
    trainer = registry.get_trainer_class(config.get("trainer", "simple"))(
        task=config["task"],
        model=config["model"],
        dataset=config["dataset"],
        optimizer=config["optim"],
        identifier=config["identifier"],
        run_dir=config.get("run_dir", "./"),
        is_debug=config.get("is_debug", False),
        is_vis=config.get("is_vis", False),
        is_hpo=config.get("is_hpo", True),  # hpo
        print_every=config.get("print_every", 10),
        seed=config.get("seed", 0),
        logger=config.get("logger", None),  # hpo
        local_rank=config["local_rank"],
        amp=config.get("amp", False),
        cpu=config.get("cpu", False),
    )
    # add checkpoint here
    if checkpoint_dir:
        checkpoint = os.path.join(checkpoint_dir, "checkpoint")
        trainer.load_pretrained(checkpoint)
    # set learning rate
    for g in trainer.optimizer.param_groups:
        g["lr"] = config["lr"]
    # start training
    trainer.train()
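
# A minimal sketch of how ocp_trainable could be driven by Ray Tune's
# population-based training (PBT). It assumes the classic `tune.run` API and a
# `base_config` dict built elsewhere (e.g. from the usual OCP yaml flags);
# both are assumptions, not part of the snippet above, and the exact arguments
# depend on the installed Ray version.
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining

pbt = PopulationBasedTraining(
    time_attr="training_iteration",
    metric="val_loss",
    mode="min",
    perturbation_interval=4,
    hyperparam_mutations={"lr": tune.loguniform(1e-5, 1e-3)},
)
analysis = tune.run(
    ocp_trainable,
    # base_config is hypothetical; it must carry the usual OCP config keys.
    config={**base_config, "lr": tune.loguniform(1e-5, 1e-3)},
    scheduler=pbt,
    resources_per_trial={"cpu": 8, "gpu": 1},
    num_samples=4,
)
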
def main(config):
    if args.distributed:
        distutils.setup(config)

    try:
        setup_imports()
        trainer = registry.get_trainer_class(config.get("trainer", "simple"))(
            task=config["task"],
            model=config["model"],
            dataset=config["dataset"],
            optimizer=config["optim"],
            identifier=config["identifier"],
            run_dir=config.get("run_dir", "./"),
            is_debug=config.get("is_debug", False),
            is_vis=config.get("is_vis", False),
            print_every=config.get("print_every", 10),
            seed=config.get("seed", 0),
            logger=config.get("logger", "tensorboard"),
            local_rank=config["local_rank"],
            amp=config.get("amp", False),
            cpu=config.get("cpu", False),
        )

        if config["checkpoint"] is not None:
            trainer.load_pretrained(config["checkpoint"])

        start_time = time.time()
        if config["mode"] == "train":
            trainer.train()
        elif config["mode"] == "predict":
            assert (
                trainer.test_loader is not None
            ), "Test dataset is required for making predictions"
            assert config["checkpoint"]
            results_file = "predictions"
            trainer.predict(
                trainer.test_loader,
                results_file=results_file,
                disable_tqdm=False,
            )
        elif config["mode"] == "run-relaxations":
            assert isinstance(
                trainer, ForcesTrainer
            ), "Relaxations are only possible for ForcesTrainer"
            assert (
                trainer.relax_dataset is not None
            ), "Relax dataset is required for making predictions"
            assert config["checkpoint"]
            trainer.run_relaxations()

        distutils.synchronize()
        if distutils.is_master():
            print("Total time taken = ", time.time() - start_time)
    finally:
        if args.distributed:
            distutils.cleanup()
def __init__(self, config_yml, checkpoint=None, cutoff=6, max_neighbors=50):
    """
    OCP-ASE Calculator

    Args:
        config_yml (str): Path to yaml config.
        checkpoint (str): Path to trained checkpoint.
        cutoff (int): Cutoff radius to be used for data preprocessing.
        max_neighbors (int): Maximum number of neighbors to store for a given atom.
    """
    setup_imports()
    setup_logging()
    Calculator.__init__(self)

    config = yaml.safe_load(open(config_yml, "r"))
    if "includes" in config:
        for include in config["includes"]:
            include_config = yaml.safe_load(open(include, "r"))
            config.update(include_config)

    # Save config so obj can be transported over network (pkl)
    self.config = copy.deepcopy(config)
    self.config["checkpoint"] = checkpoint

    self.trainer = registry.get_trainer_class(
        config.get("trainer", "simple")
    )(
        task=config["task"],
        model=config["model"],
        dataset=config["dataset"],
        optimizer=config["optim"],
        identifier="",
        slurm=config.get("slurm", {}),
        local_rank=config.get("local_rank", 0),
        is_debug=config.get("is_debug", True),
        cpu=True,
    )

    if checkpoint is not None:
        self.load_checkpoint(checkpoint)

    self.a2g = AtomsToGraphs(
        max_neigh=max_neighbors,
        radius=cutoff,
        r_energy=False,
        r_forces=False,
        r_distances=False,
    )
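
# Hedged usage sketch for the OCP-ASE calculator above. It assumes the
# enclosing class (called `OCPCalculator` here as a placeholder) implements
# ASE's `calculate()`, and that "config.yml" / "checkpoint.pt" are valid local
# paths -- none of these names come from the snippet itself.
from ase.build import add_adsorbate, fcc100, molecule
from ase.optimize import LBFGS

calc = OCPCalculator(config_yml="config.yml", checkpoint="checkpoint.pt")

slab = fcc100("Cu", size=(3, 3, 3), vacuum=8.0)
add_adsorbate(slab, molecule("CO"), height=2.0, position="ontop")
slab.calc = calc

# Relax the adsorbate/slab system with the ML potential as the force provider.
opt = LBFGS(slab)
opt.run(fmax=0.05, steps=100)
print("relaxed energy:", slab.get_potential_energy())
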
def main(config):
    setup_imports()
    trainer = registry.get_trainer_class(config.get("trainer", "simple"))(
        task=config["task"],
        model=config["model"],
        dataset=config["dataset"],
        optimizer=config["optim"],
        identifier=config["identifier"],
        run_dir=config.get("run_dir", "./"),
        is_debug=config.get("is_debug", False),
        is_vis=config.get("is_vis", False),
        print_every=config.get("print_every", 10),
        seed=config.get("seed", 0),
        logger=config.get("logger", "tensorboard"),
        local_rank=config["local_rank"],
        amp=config.get("amp", False),
    )

    import time

    start_time = time.time()
    trainer.train()
    distutils.synchronize()
    print("Time = ", time.time() - start_time)
def __call__(self, config):
    setup_logging()
    self.config = copy.deepcopy(config)

    if args.distributed:
        distutils.setup(config)

    try:
        setup_imports()
        self.trainer = registry.get_trainer_class(
            config.get("trainer", "simple")
        )(
            task=config["task"],
            model=config["model"],
            dataset=config["dataset"],
            optimizer=config["optim"],
            identifier=config["identifier"],
            timestamp_id=config.get("timestamp_id", None),
            run_dir=config.get("run_dir", "./"),
            is_debug=config.get("is_debug", False),
            is_vis=config.get("is_vis", False),
            print_every=config.get("print_every", 10),
            seed=config.get("seed", 0),
            logger=config.get("logger", "tensorboard"),
            local_rank=config["local_rank"],
            amp=config.get("amp", False),
            cpu=config.get("cpu", False),
            slurm=config.get("slurm", {}),
        )
        self.task = registry.get_task_class(config["mode"])(self.config)
        self.task.setup(self.trainer)
        start_time = time.time()
        self.task.run()
        distutils.synchronize()
        if distutils.is_master():
            logging.info(f"Total time taken: {time.time() - start_time}")
    finally:
        if args.distributed:
            distutils.cleanup()
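
# Illustrative driver for the Runner-style __call__ above, mirroring how the
# fitting script below parses flags and builds a config. `Runner` is an
# assumed name for the enclosing class; `flags` and `build_config` are the
# same helpers used elsewhere in these snippets.
if __name__ == "__main__":
    setup_logging()
    parser = flags.get_parser()
    args, override_args = parser.parse_known_args()
    config = build_config(args, override_args)
    Runner()(config)
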
if __name__ == "__main__": setup_logging() num_batches = 16 # number of batches to use to fit a single variable parser = flags.get_parser() args, override_args = parser.parse_known_args() config = build_config(args, override_args) assert config["model"]["name"].startswith("gemnet") config["logger"] = "tensorboard" if args.distributed: raise ValueError( "I don't think this works with DDP (race conditions).") setup_imports() scale_file = config["model"]["scale_file"] logging.info(f"Run fitting for model: {args.identifier}") logging.info(f"Target scale file: {scale_file}") def initialize_scale_file(scale_file): # initialize file preset = {"comment": args.identifier} write_json(scale_file, preset) if os.path.exists(scale_file): logging.warning(f"Already found existing file: {scale_file}") flag = input( "Do you want to continue and overwrite the file (1), "
def __init__(
    self, config_yml=None, checkpoint=None, cutoff=6, max_neighbors=50
):
    setup_imports()
    setup_logging()

    # Either the config path or the checkpoint path needs to be provided
    assert config_yml or checkpoint is not None

    if config_yml is not None:
        if isinstance(config_yml, str):
            config = yaml.safe_load(open(config_yml, "r"))

            if "includes" in config:
                for include in config["includes"]:
                    # Change the path based on absolute path of config_yml
                    path = os.path.join(
                        config_yml.split("configs")[0], include
                    )
                    include_config = yaml.safe_load(open(path, "r"))
                    config.update(include_config)
        else:
            config = config_yml
        # Only keeps the train data that might have normalizer values
        config["dataset"] = config["dataset"][0]
    else:
        # Loads the config from the checkpoint directly
        config = torch.load(
            checkpoint, map_location=torch.device("cpu")
        )["config"]

        # Load the trainer based on the dataset used
        if config["task"]["dataset"] == "trajectory_lmdb":
            config["trainer"] = "forces"
        else:
            config["trainer"] = "energy"

        config["model_attributes"]["name"] = config.pop("model")
        config["model"] = config["model_attributes"]

    # Calculate the edge indices on the fly
    config["model"]["otf_graph"] = True

    # Save config so obj can be transported over network (pkl)
    self.config = copy.deepcopy(config)
    self.config["checkpoint"] = checkpoint

    if "normalizer" not in config:
        del config["dataset"]["src"]
        config["normalizer"] = config["dataset"]

    super().__init__(
        task=config["task"],
        model=config["model"],
        dataset=None,
        optimizer=config["optim"],
        identifier="",
        normalizer=config["normalizer"],
        slurm=config.get("slurm", {}),
        local_rank=config.get("local_rank", 0),
        logger=config.get("logger", None),
        print_every=config.get("print_every", 1),
        is_debug=config.get("is_debug", True),
        cpu=True,
    )

    if checkpoint is not None:
        try:
            self.load_checkpoint(checkpoint)
        except NotImplementedError:
            logging.warning("Unable to load checkpoint!")

    self.a2g = AtomsToGraphs(
        max_neigh=max_neighbors,
        radius=cutoff,
        r_energy=False,
        r_forces=False,
        r_distances=False,
    )
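
# Sketch of the two construction modes supported above. The enclosing class is
# not shown in the snippet, so `PretrainedPotential` is a placeholder name, and
# the config/checkpoint paths are assumptions.

# 1) Explicit config plus checkpoint.
pot = PretrainedPotential(
    config_yml="configs/s2ef/2M/gemnet/gemnet-dT.yml",
    checkpoint="gemnet_t_direct_h512_all.pt",
)

# 2) Checkpoint only: the config is recovered from the checkpoint file itself.
pot = PretrainedPotential(checkpoint="gemnet_t_direct_h512_all.pt")

# The stored converter turns an ase.Atoms object into a graph data object:
# data = pot.a2g.convert(atoms)
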
def oc20_initialize(model_name, gpu=True):
    """
    Initialize GNNP of OC20 (i.e. S2EF).

    Args:
        model_name (str): name of model for GNNP. One can use one of the following:
            - "DimeNet++"
            - "GemNet-dT"
            - "CGCNN"
            - "SchNet"
            - "SpinConv"
        gpu (bool): using GPU, if possible.
    Returns:
        cutoff: cutoff radius.
    """
    setup_imports()
    setup_logging()

    # Check model_name
    log_file = open("log.oc20", "w")
    log_file.write("\n")
    log_file.write("model_name = " + model_name + "\n")

    if model_name is not None:
        model_name = model_name.lower()

    if model_name == "DimeNet++".lower():
        config_yml = "dimenetpp.yml"
        checkpoint = "dimenetpp_all.pt"
    elif model_name == "GemNet-dT".lower():
        config_yml = "gemnet.yml"
        checkpoint = "gemnet_t_direct_h512_all.pt"
    elif model_name == "CGCNN".lower():
        config_yml = "cgcnn.yml"
        checkpoint = "cgcnn_all.pt"
    elif model_name == "SchNet".lower():
        config_yml = "schnet.yml"
        checkpoint = "schnet_all_large.pt"
    elif model_name == "SpinConv".lower():
        config_yml = "spinconv.yml"
        checkpoint = "spinconv_force_centric_all.pt"
    else:
        raise Exception("incorrect model_name.")

    basePath = os.path.dirname(os.path.abspath(__file__))
    config_dir = os.path.normpath(os.path.join(basePath, "oc20_configs"))
    checkpt_dir = os.path.normpath(os.path.join(basePath, "oc20_checkpt"))
    config_yml = os.path.normpath(os.path.join(config_dir, config_yml))
    checkpoint = os.path.normpath(os.path.join(checkpt_dir, checkpoint))

    log_file.write("config_yml = " + config_yml + "\n")
    log_file.write("checkpoint = " + checkpoint + "\n")

    # Check gpu
    gpu_ = (gpu and torch.cuda.is_available())

    log_file.write("gpu (in) = " + str(gpu) + "\n")
    log_file.write("gpu (eff) = " + str(gpu_) + "\n")

    # Load configuration
    config = yaml.safe_load(open(config_yml, "r"))

    # Check max_neigh and cutoff
    max_neigh = config["model"].get("max_neighbors", 50)
    cutoff = config["model"].get("cutoff", 6.0)

    log_file.write("max_neigh = " + str(max_neigh) + "\n")
    log_file.write("cutoff = " + str(cutoff) + "\n")

    assert max_neigh > 0
    assert cutoff > 0.0

    # To calculate the edge indices on-the-fly
    config["model"]["otf_graph"] = True

    # Modify path of scale_file for GemNet-dT
    scale_file = config["model"].get("scale_file", None)

    if scale_file is not None:
        scale_file = os.path.normpath(os.path.join(config_dir, scale_file))
        config["model"]["scale_file"] = scale_file

    log_file.write("\nconfig:\n")
    log_file.write(pprint.pformat(config) + "\n")
    log_file.write("\n")
    log_file.close()

    # Create trainer; pre-trained weights are loaded below
    global myTrainer

    myTrainer = registry.get_trainer_class(config.get("trainer", "forces"))(
        task=config["task"],
        model=config["model"],
        dataset=None,
        normalizer=config["normalizer"],
        optimizer=config["optim"],
        identifier="",
        slurm=config.get("slurm", {}),
        local_rank=config.get("local_rank", 0),
        is_debug=config.get("is_debug", True),
        cpu=not gpu_,
    )

    # Load checkpoint
    myTrainer.load_checkpoint(checkpoint)

    # ASE Atoms object, empty at this point
    global myAtoms
    myAtoms = None

    # Converter: Atoms -> Graphs (edges are built on the fly by the model)
    global myA2G
    myA2G = AtomsToGraphs(
        max_neigh=max_neigh,
        radius=cutoff,
        r_energy=False,
        r_forces=False,
        r_distances=False,
        r_edges=False,
        r_fixed=False,
    )

    return cutoff
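
# Minimal usage sketch for oc20_initialize. Only the initialization call is
# shown; the companion functions that attach an ASE Atoms object and evaluate
# energies/forces are not part of the snippet above.
cutoff = oc20_initialize(model_name="GemNet-dT", gpu=True)
print("neighbor-list cutoff (Angstrom):", cutoff)
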