def validate(self, split="val", epoch=None): print("### Evaluating on {}.".format(split)) self.model.eval() evaluator, metrics = Evaluator(task="is2re"), {} loader = self.val_loader if split == "val" else self.test_loader for i, batch in tqdm(enumerate(loader), total=len(loader)): # Forward. with torch.cuda.amp.autocast(enabled=self.scaler is not None): out = self._forward(batch) loss = self._compute_loss(out, batch) # Compute metrics. metrics = self._compute_metrics(out, batch, evaluator, metrics) metrics = evaluator.update("loss", loss.item(), metrics) log_dict = {k: metrics[k]["metric"] for k in metrics} log_dict.update({"epoch": epoch + 1}) log_str = ["{}: {:.4f}".format(k, v) for k, v in log_dict.items()] print(", ".join(log_str)) # Make plots. if self.logger is not None and epoch is not None: self.logger.log( log_dict, step=(epoch + 1) * len(self.train_loader), split=split, ) return metrics
def validate(self, split="val", epoch=None, disable_tqdm=False): if distutils.is_master(): print("### Evaluating on {}.".format(split)) if self.is_hpo: disable_tqdm = True self.model.eval() evaluator, metrics = Evaluator(task=self.name), {} rank = distutils.get_rank() loader = self.val_loader if split == "val" else self.test_loader for i, batch in tqdm( enumerate(loader), total=len(loader), position=rank, desc="device {}".format(rank), disable=disable_tqdm, ): # Forward. with torch.cuda.amp.autocast(enabled=self.scaler is not None): out = self._forward(batch) loss = self._compute_loss(out, batch) # Compute metrics. metrics = self._compute_metrics(out, batch, evaluator, metrics) metrics = evaluator.update("loss", loss.item(), metrics) aggregated_metrics = {} for k in metrics: aggregated_metrics[k] = { "total": distutils.all_reduce(metrics[k]["total"], average=False, device=self.device), "numel": distutils.all_reduce(metrics[k]["numel"], average=False, device=self.device), } aggregated_metrics[k]["metric"] = (aggregated_metrics[k]["total"] / aggregated_metrics[k]["numel"]) metrics = aggregated_metrics log_dict = {k: metrics[k]["metric"] for k in metrics} log_dict.update({"epoch": epoch + 1}) if distutils.is_master(): log_str = ["{}: {:.4f}".format(k, v) for k, v in log_dict.items()] print(", ".join(log_str)) # Make plots. if self.logger is not None and epoch is not None: self.logger.log( log_dict, step=(epoch + 1) * len(self.train_loader), split=split, ) return metrics
def load_evaluator_is2re(request):
    request.cls.evaluator = Evaluator(task="is2re")
    prediction = {
        "energy": torch.randn(50),
    }
    target = {
        "energy": torch.randn(50),
    }
    request.cls.metrics = request.cls.evaluator.eval(prediction, target)
def load_evaluator_is2rs(request):
    request.cls.evaluator = Evaluator(task="is2rs")
    prediction = {
        "positions": torch.randn(50, 3),
    }
    target = {
        "positions": torch.randn(50, 3),
    }
    request.cls.metrics = request.cls.evaluator.eval(prediction, target)
def load_evaluator_s2ef(request):
    request.cls.evaluator = Evaluator(task="s2ef")
    prediction = {
        "energy": torch.randn(6),
        "forces": torch.randn(1000000, 3),
        "natoms": torch.tensor(
            (100000, 200000, 300000, 200000, 100000, 100000)
        ),
    }
    target = {
        "energy": torch.randn(6),
        "forces": torch.randn(1000000, 3),
        "natoms": torch.tensor(
            (100000, 200000, 300000, 200000, 100000, 100000)
        ),
    }
    request.cls.metrics = request.cls.evaluator.eval(prediction, target)
def load_evaluator_is2rs(request):
    request.cls.evaluator = Evaluator(task="is2rs")
    prediction = {
        "positions": torch.randn(50, 3),
        "natoms": torch.tensor((5, 5, 10, 12, 18)),
        "cell": torch.randn(5, 3, 3),
        "pbc": torch.tensor([True, True, True]),
    }
    target = {
        "positions": torch.randn(50, 3),
        "cell": torch.randn(5, 3, 3),
        "natoms": torch.tensor((5, 5, 10, 12, 18)),
        "pbc": torch.tensor([True, True, True]),
    }
    request.cls.metrics = request.cls.evaluator.eval(prediction, target)
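# A minimal sketch of how the fixtures above might be consumed in a pytest
# class, assuming they are registered with @pytest.fixture(scope="class") (the
# request.cls pattern suggests this) and that Evaluator.eval returns a dict of
# {metric_name: {"metric": ...}}, which is how the trainers read metrics
# elsewhere. The test class and assertions are illustrative, not from the source.
import pytest


@pytest.mark.usefixtures("load_evaluator_is2re")
class TestIS2REEvaluator:
    def test_metrics_populated(self):
        # Every reported metric entry should at least expose its scalar value.
        assert len(self.metrics) > 0
        for name, entry in self.metrics.items():
            assert "metric" in entry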
def run_relaxations(self, split="val", epoch=None):
    print("### Running ML-relaxations")
    self.model.eval()

    evaluator, metrics = Evaluator(task="is2rs"), {}

    if hasattr(self.relax_dataset[0], "pos_relaxed") and hasattr(
        self.relax_dataset[0], "y_relaxed"
    ):
        split = "val"
    else:
        split = "test"

    ids = []
    relaxed_positions = []
    for i, batch in tqdm(
        enumerate(self.relax_loader), total=len(self.relax_loader)
    ):
        relaxed_batch = ml_relax(
            batch=batch,
            model=self,
            steps=self.config["task"].get("relaxation_steps", 200),
            fmax=self.config["task"].get("relaxation_fmax", 0.0),
            relax_opt=self.config["task"]["relax_opt"],
            device=self.device,
            transform=None,
        )

        if self.config["task"].get("write_pos", False):
            systemids = [str(i) for i in relaxed_batch.sid.tolist()]
            natoms = relaxed_batch.natoms.tolist()
            positions = torch.split(relaxed_batch.pos, natoms)
            batch_relaxed_positions = [pos.tolist() for pos in positions]

            relaxed_positions += batch_relaxed_positions
            ids += systemids

        if split == "val":
            mask = relaxed_batch.fixed == 0
            s_idx = 0
            natoms_free = []
            for natoms in relaxed_batch.natoms:
                natoms_free.append(
                    torch.sum(mask[s_idx : s_idx + natoms]).item()
                )
                s_idx += natoms

            target = {
                "energy": relaxed_batch.y_relaxed,
                "positions": relaxed_batch.pos_relaxed[mask],
                "cell": relaxed_batch.cell,
                "pbc": torch.tensor([True, True, True]),
                "natoms": torch.LongTensor(natoms_free),
            }
            prediction = {
                "energy": relaxed_batch.y,
                "positions": relaxed_batch.pos[mask],
                "cell": relaxed_batch.cell,
                "pbc": torch.tensor([True, True, True]),
                "natoms": torch.LongTensor(natoms_free),
            }
            metrics = evaluator.eval(prediction, target, metrics)

    if self.config["task"].get("write_pos", False):
        rank = distutils.get_rank()
        pos_filename = os.path.join(
            self.config["cmd"]["results_dir"], f"relaxed_pos_{rank}.npz"
        )
        np.savez_compressed(
            pos_filename,
            ids=ids,
            pos=np.array(relaxed_positions, dtype=object),
        )

        distutils.synchronize()
        if distutils.is_master():
            gather_results = defaultdict(list)
            full_path = os.path.join(
                self.config["cmd"]["results_dir"],
                "relaxed_positions.npz",
            )

            for i in range(distutils.get_world_size()):
                rank_path = os.path.join(
                    self.config["cmd"]["results_dir"],
                    f"relaxed_pos_{i}.npz",
                )
                rank_results = np.load(rank_path, allow_pickle=True)
                gather_results["ids"].extend(rank_results["ids"])
                gather_results["pos"].extend(rank_results["pos"])
                os.remove(rank_path)

            # Because of how distributed sampler works, some system ids
            # might be repeated to make no. of samples even across GPUs.
            _, idx = np.unique(gather_results["ids"], return_index=True)
            gather_results["ids"] = np.array(gather_results["ids"])[idx]
            gather_results["pos"] = np.array(
                gather_results["pos"], dtype=object
            )[idx]

            print(f"Writing results to {full_path}")
            np.savez_compressed(full_path, **gather_results)

    if split == "val":
        aggregated_metrics = {}
        for k in metrics:
            aggregated_metrics[k] = {
                "total": distutils.all_reduce(
                    metrics[k]["total"], average=False, device=self.device
                ),
                "numel": distutils.all_reduce(
                    metrics[k]["numel"], average=False, device=self.device
                ),
            }
            aggregated_metrics[k]["metric"] = (
                aggregated_metrics[k]["total"]
                / aggregated_metrics[k]["numel"]
            )
        metrics = aggregated_metrics

        # Make plots.
        log_dict = {k: metrics[k]["metric"] for k in metrics}
        if self.logger is not None and epoch is not None:
            self.logger.log(
                log_dict,
                step=(epoch + 1) * len(self.train_loader),
                split=split,
            )

        if distutils.is_master():
            print(metrics)
def __init__(
    self,
    task,
    model,
    dataset,
    optimizer,
    identifier,
    run_dir=None,
    is_debug=False,
    is_vis=False,
    print_every=100,
    seed=None,
    logger="tensorboard",
    local_rank=0,
    amp=False,
    name="base_trainer",
):
    self.name = name
    if torch.cuda.is_available():
        self.device = local_rank
    else:
        self.device = "cpu"

    if run_dir is None:
        run_dir = os.getcwd()
    run_dir = Path(run_dir)

    timestamp = torch.tensor(datetime.datetime.now().timestamp()).to(
        self.device
    )
    # create directories from master rank only
    distutils.broadcast(timestamp, 0)
    timestamp = datetime.datetime.fromtimestamp(timestamp).strftime(
        "%Y-%m-%d-%H-%M-%S"
    )
    if identifier:
        timestamp += "-{}".format(identifier)

    self.config = {
        "task": task,
        "model": model.pop("name"),
        "model_attributes": model,
        "optim": optimizer,
        "logger": logger,
        "amp": amp,
        "cmd": {
            "identifier": identifier,
            "print_every": print_every,
            "seed": seed,
            "timestamp": timestamp,
            "checkpoint_dir": str(run_dir / "checkpoints" / timestamp),
            "results_dir": str(run_dir / "results" / timestamp),
            "logs_dir": str(run_dir / "logs" / logger / timestamp),
        },
    }
    # AMP Scaler
    self.scaler = torch.cuda.amp.GradScaler() if amp else None

    if isinstance(dataset, list):
        self.config["dataset"] = dataset[0]
        if len(dataset) > 1:
            self.config["val_dataset"] = dataset[1]
        if len(dataset) > 2:
            self.config["test_dataset"] = dataset[2]
    else:
        self.config["dataset"] = dataset

    if not is_debug and distutils.is_master():
        os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True)
        os.makedirs(self.config["cmd"]["results_dir"], exist_ok=True)
        os.makedirs(self.config["cmd"]["logs_dir"], exist_ok=True)

    self.is_debug = is_debug
    self.is_vis = is_vis

    if distutils.is_master():
        print(yaml.dump(self.config, default_flow_style=False))
    self.load()

    self.evaluator = Evaluator(task=name)
def __init__(
    self,
    task,
    model,
    dataset,
    optimizer,
    identifier,
    run_dir=None,
    is_debug=False,
    is_vis=False,
    is_hpo=False,
    print_every=100,
    seed=None,
    logger="tensorboard",
    local_rank=0,
    amp=False,
    cpu=False,
    name="base_trainer",
):
    self.name = name
    self.cpu = cpu
    self.start_step = 0

    if torch.cuda.is_available() and not self.cpu:
        self.device = local_rank
    else:
        self.device = "cpu"
        self.cpu = True  # handle case when `--cpu` isn't specified
        # but there are no gpu devices available

    if run_dir is None:
        run_dir = os.getcwd()

    timestamp = torch.tensor(datetime.datetime.now().timestamp()).to(
        self.device
    )
    # create directories from master rank only
    distutils.broadcast(timestamp, 0)
    timestamp = datetime.datetime.fromtimestamp(timestamp.int()).strftime(
        "%Y-%m-%d-%H-%M-%S"
    )
    if identifier:
        timestamp += "-{}".format(identifier)

    try:
        commit_hash = (
            subprocess.check_output(
                [
                    "git",
                    "-C",
                    ocpmodels.__path__[0],
                    "describe",
                    "--always",
                ]
            )
            .strip()
            .decode("ascii")
        )
    # catch instances where code is not being run from a git repo
    except Exception:
        commit_hash = None

    self.config = {
        "task": task,
        "model": model.pop("name"),
        "model_attributes": model,
        "optim": optimizer,
        "logger": logger,
        "amp": amp,
        "gpus": distutils.get_world_size() if not self.cpu else 0,
        "cmd": {
            "identifier": identifier,
            "print_every": print_every,
            "seed": seed,
            "timestamp": timestamp,
            "commit": commit_hash,
            "checkpoint_dir": os.path.join(run_dir, "checkpoints", timestamp),
            "results_dir": os.path.join(run_dir, "results", timestamp),
            "logs_dir": os.path.join(run_dir, "logs", logger, timestamp),
        },
    }
    # AMP Scaler
    self.scaler = torch.cuda.amp.GradScaler() if amp else None

    if isinstance(dataset, list):
        self.config["dataset"] = dataset[0]
        if len(dataset) > 1:
            self.config["val_dataset"] = dataset[1]
        if len(dataset) > 2:
            self.config["test_dataset"] = dataset[2]
    else:
        self.config["dataset"] = dataset

    if not is_debug and distutils.is_master() and not is_hpo:
        os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True)
        os.makedirs(self.config["cmd"]["results_dir"], exist_ok=True)
        os.makedirs(self.config["cmd"]["logs_dir"], exist_ok=True)

    self.is_debug = is_debug
    self.is_vis = is_vis
    self.is_hpo = is_hpo

    if self.is_hpo:
        # sets the hpo checkpoint frequency
        # default is no checkpointing
        self.hpo_checkpoint_every = self.config["optim"].get(
            "checkpoint_every", -1
        )

    if distutils.is_master():
        print(yaml.dump(self.config, default_flow_style=False))
    self.load()

    self.evaluator = Evaluator(task=name)
def __init__(
    self,
    task,
    model,
    dataset,
    optimizer,
    identifier,
    normalizer=None,
    timestamp_id=None,
    run_dir=None,
    is_debug=False,
    is_vis=False,
    is_hpo=False,
    print_every=100,
    seed=None,
    logger="tensorboard",
    local_rank=0,
    amp=False,
    cpu=False,
    name="base_trainer",
    slurm={},
):
    self.name = name
    self.cpu = cpu
    self.epoch = 0
    self.step = 0

    if torch.cuda.is_available() and not self.cpu:
        self.device = torch.device(f"cuda:{local_rank}")
    else:
        self.device = torch.device("cpu")
        self.cpu = True  # handle case when `--cpu` isn't specified
        # but there are no gpu devices available

    if run_dir is None:
        run_dir = os.getcwd()

    if timestamp_id is None:
        timestamp = torch.tensor(datetime.datetime.now().timestamp()).to(
            self.device
        )
        # create directories from master rank only
        distutils.broadcast(timestamp, 0)
        timestamp = datetime.datetime.fromtimestamp(
            timestamp.int()
        ).strftime("%Y-%m-%d-%H-%M-%S")
        if identifier:
            self.timestamp_id = f"{timestamp}-{identifier}"
        else:
            self.timestamp_id = timestamp
    else:
        self.timestamp_id = timestamp_id

    try:
        commit_hash = (
            subprocess.check_output(
                [
                    "git",
                    "-C",
                    ocpmodels.__path__[0],
                    "describe",
                    "--always",
                ]
            )
            .strip()
            .decode("ascii")
        )
    # catch instances where code is not being run from a git repo
    except Exception:
        commit_hash = None

    self.config = {
        "task": task,
        "model": model.pop("name"),
        "model_attributes": model,
        "optim": optimizer,
        "logger": logger,
        "amp": amp,
        "gpus": distutils.get_world_size() if not self.cpu else 0,
        "cmd": {
            "identifier": identifier,
            "print_every": print_every,
            "seed": seed,
            "timestamp_id": self.timestamp_id,
            "commit": commit_hash,
            "checkpoint_dir": os.path.join(
                run_dir, "checkpoints", self.timestamp_id
            ),
            "results_dir": os.path.join(
                run_dir, "results", self.timestamp_id
            ),
            "logs_dir": os.path.join(
                run_dir, "logs", logger, self.timestamp_id
            ),
        },
        "slurm": slurm,
    }
    # AMP Scaler
    self.scaler = torch.cuda.amp.GradScaler() if amp else None

    if "SLURM_JOB_ID" in os.environ and "folder" in self.config["slurm"]:
        self.config["slurm"]["job_id"] = os.environ["SLURM_JOB_ID"]
        self.config["slurm"]["folder"] = self.config["slurm"][
            "folder"
        ].replace("%j", self.config["slurm"]["job_id"])

    if isinstance(dataset, list):
        if len(dataset) > 0:
            self.config["dataset"] = dataset[0]
        if len(dataset) > 1:
            self.config["val_dataset"] = dataset[1]
        if len(dataset) > 2:
            self.config["test_dataset"] = dataset[2]
    elif isinstance(dataset, dict):
        self.config["dataset"] = dataset.get("train", None)
        self.config["val_dataset"] = dataset.get("val", None)
        self.config["test_dataset"] = dataset.get("test", None)
    else:
        self.config["dataset"] = dataset

    self.normalizer = normalizer
    # This supports the legacy way of providing norm parameters in dataset
    if self.config.get("dataset", None) is not None and normalizer is None:
        self.normalizer = self.config["dataset"]

    if not is_debug and distutils.is_master() and not is_hpo:
        os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True)
        os.makedirs(self.config["cmd"]["results_dir"], exist_ok=True)
        os.makedirs(self.config["cmd"]["logs_dir"], exist_ok=True)

    self.is_debug = is_debug
    self.is_vis = is_vis
    self.is_hpo = is_hpo

    if self.is_hpo:
        # conditional import is necessary for checkpointing
        from ray import tune

        from ocpmodels.common.hpo_utils import tune_reporter

        # sets the hpo checkpoint frequency
        # default is no checkpointing
        self.hpo_checkpoint_every = self.config["optim"].get(
            "checkpoint_every", -1
        )

    if distutils.is_master():
        print(yaml.dump(self.config, default_flow_style=False))

    self.load()

    self.evaluator = Evaluator(task=name)
def __init__(
    self,
    task,
    model,
    dataset,
    optimizer,
    identifier,
    run_dir=None,
    is_debug=False,
    is_vis=False,
    print_every=100,
    seed=None,
    logger="tensorboard",
    local_rank=0,
    amp=False,
):
    if run_dir is None:
        run_dir = os.getcwd()

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    if identifier:
        timestamp += "-{}".format(identifier)

    self.config = {
        "task": task,
        "model": model.pop("name"),
        "model_attributes": model,
        "optim": optimizer,
        "logger": logger,
        "cmd": {
            "identifier": identifier,
            "print_every": print_every,
            "seed": seed,
            "timestamp": timestamp,
            "checkpoint_dir": os.path.join(run_dir, "checkpoints", timestamp),
            "results_dir": os.path.join(run_dir, "results", timestamp),
            "logs_dir": os.path.join(run_dir, "logs", logger, timestamp),
        },
        "amp": amp,
    }

    # AMP Scaler
    self.scaler = torch.cuda.amp.GradScaler() if amp else None

    if isinstance(dataset, list):
        self.config["dataset"] = dataset[0]
        if len(dataset) > 1:
            self.config["val_dataset"] = dataset[1]
    else:
        self.config["dataset"] = dataset

    if not is_debug:
        os.makedirs(self.config["cmd"]["checkpoint_dir"])
        os.makedirs(self.config["cmd"]["results_dir"])
        os.makedirs(self.config["cmd"]["logs_dir"])

    self.is_debug = is_debug
    self.is_vis = is_vis

    if torch.cuda.is_available():
        device_ids = list(range(torch.cuda.device_count()))
        self.output_device = self.config["optim"].get(
            "output_device", device_ids[0]
        )
        self.device = f"cuda:{self.output_device}"
    else:
        self.device = "cpu"

    print(yaml.dump(self.config, default_flow_style=False))
    self.load()

    self.evaluator = Evaluator(task="is2re")
class EnergyTrainer(BaseTrainer):
    def __init__(
        self,
        task,
        model,
        dataset,
        optimizer,
        identifier,
        run_dir=None,
        is_debug=False,
        is_vis=False,
        print_every=100,
        seed=None,
        logger="tensorboard",
        local_rank=0,
        amp=False,
    ):
        if run_dir is None:
            run_dir = os.getcwd()

        timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        if identifier:
            timestamp += "-{}".format(identifier)

        self.config = {
            "task": task,
            "model": model.pop("name"),
            "model_attributes": model,
            "optim": optimizer,
            "logger": logger,
            "cmd": {
                "identifier": identifier,
                "print_every": print_every,
                "seed": seed,
                "timestamp": timestamp,
                "checkpoint_dir": os.path.join(run_dir, "checkpoints", timestamp),
                "results_dir": os.path.join(run_dir, "results", timestamp),
                "logs_dir": os.path.join(run_dir, "logs", logger, timestamp),
            },
            "amp": amp,
        }

        # AMP Scaler
        self.scaler = torch.cuda.amp.GradScaler() if amp else None

        if isinstance(dataset, list):
            self.config["dataset"] = dataset[0]
            if len(dataset) > 1:
                self.config["val_dataset"] = dataset[1]
        else:
            self.config["dataset"] = dataset

        if not is_debug:
            os.makedirs(self.config["cmd"]["checkpoint_dir"])
            os.makedirs(self.config["cmd"]["results_dir"])
            os.makedirs(self.config["cmd"]["logs_dir"])

        self.is_debug = is_debug
        self.is_vis = is_vis

        if torch.cuda.is_available():
            device_ids = list(range(torch.cuda.device_count()))
            self.output_device = self.config["optim"].get(
                "output_device", device_ids[0]
            )
            self.device = f"cuda:{self.output_device}"
        else:
            self.device = "cpu"

        print(yaml.dump(self.config, default_flow_style=False))
        self.load()

        self.evaluator = Evaluator(task="is2re")

    def load_task(self):
        print("### Loading dataset: {}".format(self.config["task"]["dataset"]))

        self.parallel_collater = ParallelCollater(
            self.config["optim"].get("num_gpus", 1)
        )
        if self.config["task"]["dataset"] == "single_point_lmdb":
            self.train_dataset = registry.get_dataset_class(
                self.config["task"]["dataset"]
            )(self.config["dataset"])

            self.train_loader = DataLoader(
                self.train_dataset,
                batch_size=self.config["optim"]["batch_size"],
                shuffle=True,
                collate_fn=self.parallel_collater,
                num_workers=self.config["optim"]["num_workers"],
                pin_memory=True,
            )

            self.val_loader = self.test_loader = None

            if "val_dataset" in self.config:
                self.val_dataset = registry.get_dataset_class(
                    self.config["task"]["dataset"]
                )(self.config["val_dataset"])
                self.val_loader = DataLoader(
                    self.val_dataset,
                    self.config["optim"].get("eval_batch_size", 64),
                    shuffle=False,
                    collate_fn=self.parallel_collater,
                    num_workers=self.config["optim"]["num_workers"],
                    pin_memory=True,
                )
        else:
            raise NotImplementedError

        self.num_targets = 1

        # Normalizer for the dataset.
        # Compute mean, std of training set labels.
        self.normalizers = {}
        if self.config["dataset"].get("normalize_labels", True):
            if "target_mean" in self.config["dataset"]:
                self.normalizers["target"] = Normalizer(
                    mean=self.config["dataset"]["target_mean"],
                    std=self.config["dataset"]["target_std"],
                    device=self.device,
                )
            else:
                raise NotImplementedError

    def load_model(self):
        super(EnergyTrainer, self).load_model()

        self.model = OCPDataParallel(
            self.model,
            output_device=self.output_device,
            num_gpus=self.config["optim"].get("num_gpus", 1),
        )
        self.model.to(self.device)

    def train(self):
        self.best_val_mae = 1e9
        for epoch in range(self.config["optim"]["max_epochs"]):
            self.model.train()
            for i, batch in enumerate(self.train_loader):
                # Forward, loss, backward.
                with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                    out = self._forward(batch)
                    loss = self._compute_loss(out, batch)
                loss = self.scaler.scale(loss) if self.scaler else loss
                self._backward(loss)
                scale = self.scaler.get_scale() if self.scaler else 1.0

                # Compute metrics.
                self.metrics = self._compute_metrics(
                    out,
                    batch,
                    self.evaluator,
                    metrics={},
                )
                self.metrics = self.evaluator.update(
                    "loss", loss.item() / scale, self.metrics
                )

                # Print metrics, make plots.
                log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
                log_dict.update(
                    {"epoch": epoch + (i + 1) / len(self.train_loader)}
                )
                if i % self.config["cmd"]["print_every"] == 0:
                    log_str = [
                        "{}: {:.4f}".format(k, v) for k, v in log_dict.items()
                    ]
                    print(", ".join(log_str))

                if self.logger is not None:
                    self.logger.log(
                        log_dict,
                        step=epoch * len(self.train_loader) + i + 1,
                        split="train",
                    )

            self.scheduler.step()
            torch.cuda.empty_cache()

            if self.val_loader is not None:
                val_metrics = self.validate(split="val", epoch=epoch)
                if (
                    val_metrics[self.evaluator.task_primary_metric["is2re"]][
                        "metric"
                    ]
                    < self.best_val_mae
                ):
                    self.best_val_mae = val_metrics[
                        self.evaluator.task_primary_metric["is2re"]
                    ]["metric"]
                    if not self.is_debug:
                        save_checkpoint(
                            {
                                "epoch": epoch + 1,
                                "state_dict": self.model.state_dict(),
                                "optimizer": self.optimizer.state_dict(),
                                "normalizers": {
                                    key: value.state_dict()
                                    for key, value in self.normalizers.items()
                                },
                                "config": self.config,
                                "val_metrics": val_metrics,
                                "amp": self.scaler.state_dict()
                                if self.scaler
                                else None,
                            },
                            self.config["cmd"]["checkpoint_dir"],
                        )

            if self.test_loader is not None:
                self.validate(split="test", epoch=epoch)

    def validate(self, split="val", epoch=None):
        print("### Evaluating on {}.".format(split))

        self.model.eval()
        evaluator, metrics = Evaluator(task="is2re"), {}

        loader = self.val_loader if split == "val" else self.test_loader

        for i, batch in tqdm(enumerate(loader), total=len(loader)):
            # Forward.
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                out = self._forward(batch)
                loss = self._compute_loss(out, batch)

            # Compute metrics.
            metrics = self._compute_metrics(out, batch, evaluator, metrics)
            metrics = evaluator.update("loss", loss.item(), metrics)

        log_dict = {k: metrics[k]["metric"] for k in metrics}
        log_dict.update({"epoch": epoch + 1})
        log_str = ["{}: {:.4f}".format(k, v) for k, v in log_dict.items()]
        print(", ".join(log_str))

        # Make plots.
        if self.logger is not None and epoch is not None:
            self.logger.log(
                log_dict,
                step=(epoch + 1) * len(self.train_loader),
                split=split,
            )

        return metrics

    def _forward(self, batch_list):
        output = self.model(batch_list)

        if output.shape[-1] == 1:
            output = output.view(-1)

        return {
            "energy": output,
        }

    def _compute_loss(self, out, batch_list):
        energy_target = torch.cat(
            [batch.y_relaxed.to(self.device) for batch in batch_list], dim=0
        )

        if self.config["dataset"].get("normalize_labels", True):
            target_normed = self.normalizers["target"].norm(energy_target)
        else:
            target_normed = energy_target

        loss = self.criterion(out["energy"], target_normed)
        return loss

    def _compute_metrics(self, out, batch_list, evaluator, metrics={}):
        energy_target = torch.cat(
            [batch.y_relaxed.to(self.device) for batch in batch_list], dim=0
        )

        if self.config["dataset"].get("normalize_labels", True):
            out["energy"] = self.normalizers["target"].denorm(out["energy"])

        metrics = evaluator.eval(
            out,
            {"energy": energy_target},
            prev_metrics=metrics,
        )
        return metrics
class DistributedForcesTrainer(BaseTrainer):
    def __init__(
        self,
        task,
        model,
        dataset,
        optimizer,
        identifier,
        run_dir=None,
        is_debug=False,
        is_vis=False,
        print_every=100,
        seed=None,
        logger="tensorboard",
        local_rank=0,
        amp=False,
    ):
        if run_dir is None:
            run_dir = os.getcwd()

        timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        if identifier:
            timestamp += "-{}".format(identifier)

        self.config = {
            "task": task,
            "model": model.pop("name"),
            "model_attributes": model,
            "optim": optimizer,
            "logger": logger,
            "amp": amp,
            "cmd": {
                "identifier": identifier,
                "print_every": print_every,
                "seed": seed,
                "timestamp": timestamp,
                "checkpoint_dir": os.path.join(run_dir, "checkpoints", timestamp),
                "results_dir": os.path.join(run_dir, "results", timestamp),
                "logs_dir": os.path.join(run_dir, "logs", logger, timestamp),
            },
        }

        # AMP Scaler
        self.scaler = torch.cuda.amp.GradScaler() if amp else None

        if isinstance(dataset, list):
            self.config["dataset"] = dataset[0]
            if len(dataset) > 1:
                self.config["val_dataset"] = dataset[1]
        else:
            self.config["dataset"] = dataset

        if not is_debug and distutils.is_master():
            os.makedirs(self.config["cmd"]["checkpoint_dir"])
            os.makedirs(self.config["cmd"]["results_dir"])
            os.makedirs(self.config["cmd"]["logs_dir"])

        self.is_debug = is_debug
        self.is_vis = is_vis

        if torch.cuda.is_available():
            self.device = local_rank
        else:
            self.device = "cpu"

        if distutils.is_master():
            print(yaml.dump(self.config, default_flow_style=False))
        self.load()

        self.evaluator = Evaluator(task="s2ef")

    def load_task(self):
        print("### Loading dataset: {}".format(self.config["task"]["dataset"]))

        self.parallel_collater = ParallelCollater(1)
        if self.config["task"]["dataset"] == "trajectory_lmdb":
            self.train_dataset = registry.get_dataset_class(
                self.config["task"]["dataset"]
            )(self.config["dataset"])

            self.train_loader = DataLoader(
                self.train_dataset,
                batch_size=self.config["optim"]["batch_size"],
                shuffle=True,
                collate_fn=self.parallel_collater,
                num_workers=self.config["optim"]["num_workers"],
                pin_memory=True,
            )

            self.val_loader = self.test_loader = None

            if "val_dataset" in self.config:
                self.val_dataset = registry.get_dataset_class(
                    self.config["task"]["dataset"]
                )(self.config["val_dataset"])
                self.val_loader = DataLoader(
                    self.val_dataset,
                    self.config["optim"].get("eval_batch_size", 64),
                    shuffle=False,
                    collate_fn=self.parallel_collater,
                    num_workers=self.config["optim"]["num_workers"],
                    pin_memory=True,
                )
        else:
            self.dataset = registry.get_dataset_class(
                self.config["task"]["dataset"]
            )(self.config["dataset"])
            (
                self.train_loader,
                self.val_loader,
                self.test_loader,
            ) = self.dataset.get_dataloaders(
                batch_size=self.config["optim"]["batch_size"],
                collate_fn=self.parallel_collater,
            )

        self.num_targets = 1

        # Normalizer for the dataset.
        # Compute mean, std of training set labels.
        self.normalizers = {}
        if self.config["dataset"].get("normalize_labels", True):
            if "target_mean" in self.config["dataset"]:
                self.normalizers["target"] = Normalizer(
                    mean=self.config["dataset"]["target_mean"],
                    std=self.config["dataset"]["target_std"],
                    device=self.device,
                )
            else:
                self.normalizers["target"] = Normalizer(
                    tensor=self.train_loader.dataset.data.y[
                        self.train_loader.dataset.__indices__
                    ],
                    device=self.device,
                )

        # If we're computing gradients wrt input, set mean of normalizer to 0 --
        # since it is lost when compute dy / dx -- and std to forward target std
        if self.config["model_attributes"].get("regress_forces", True):
            if self.config["dataset"].get("normalize_labels", True):
                if "grad_target_mean" in self.config["dataset"]:
                    self.normalizers["grad_target"] = Normalizer(
                        mean=self.config["dataset"]["grad_target_mean"],
                        std=self.config["dataset"]["grad_target_std"],
                        device=self.device,
                    )
                else:
                    self.normalizers["grad_target"] = Normalizer(
                        tensor=self.train_loader.dataset.data.y[
                            self.train_loader.dataset.__indices__
                        ],
                        device=self.device,
                    )
                    self.normalizers["grad_target"].mean.fill_(0)

        if (
            self.is_vis
            and self.config["task"]["dataset"] != "qm9"
            and distutils.is_master()
        ):
            # Plot label distribution.
            plots = [
                plot_histogram(
                    self.train_loader.dataset.data.y.tolist(),
                    xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                    ylabel="# Examples",
                    title="Split: train",
                ),
                plot_histogram(
                    self.val_loader.dataset.data.y.tolist(),
                    xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                    ylabel="# Examples",
                    title="Split: val",
                ),
                plot_histogram(
                    self.test_loader.dataset.data.y.tolist(),
                    xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                    ylabel="# Examples",
                    title="Split: test",
                ),
            ]
            self.logger.log_plots(plots)

    def load_model(self):
        super(DistributedForcesTrainer, self).load_model()

        self.model = OCPDataParallel(
            self.model,
            output_device=self.device,
            num_gpus=1,
        )
        self.model = DistributedDataParallel(
            self.model,
            device_ids=[self.device],
            find_unused_parameters=True,
        )

    # Takes in a new data source and generates predictions on it.
    def predict(self, dataset, batch_size=32):
        if isinstance(dataset, dict):
            if self.config["task"]["dataset"] == "trajectory_lmdb":
                print(
                    "### Generating predictions on {}.".format(dataset["src"])
                )
            else:
                print(
                    "### Generating predictions on {}.".format(
                        dataset["src"] + dataset["traj"]
                    )
                )

            dataset = registry.get_dataset_class(
                self.config["task"]["dataset"]
            )(dataset)
            data_loader = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=False,
                collate_fn=self.parallel_collater,
            )
        elif isinstance(dataset, torch_geometric.data.Batch):
            data_loader = [[dataset]]
        else:
            raise NotImplementedError

        self.model.eval()
        predictions = {"energy": [], "forces": []}

        for i, batch_list in enumerate(data_loader):
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                out = self._forward(batch_list)

            if self.normalizers is not None and "target" in self.normalizers:
                out["energy"] = self.normalizers["target"].denorm(
                    out["energy"]
                )
                out["forces"] = self.normalizers["grad_target"].denorm(
                    out["forces"]
                )
            atoms_sum = 0
            predictions["energy"].extend(out["energy"].tolist())
            batch_natoms = torch.cat([batch.natoms for batch in batch_list])
            for natoms in batch_natoms:
                predictions["forces"].append(
                    out["forces"][atoms_sum : natoms + atoms_sum]
                    .cpu()
                    .detach()
                    .numpy()
                )
                atoms_sum += natoms

        return predictions

    def train(self):
        self.best_val_mae = 1e9
        eval_every = self.config["optim"].get("eval_every", -1)
        iters = 0
        self.metrics = {}
        for epoch in range(self.config["optim"]["max_epochs"]):
            self.model.train()
            for i, batch in enumerate(self.train_loader):
                # Forward, loss, backward.
                with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                    out = self._forward(batch)
                    loss = self._compute_loss(out, batch)
                loss = self.scaler.scale(loss) if self.scaler else loss
                self._backward(loss)
                scale = self.scaler.get_scale() if self.scaler else 1.0

                # Compute metrics.
                self.metrics = self._compute_metrics(
                    out,
                    batch,
                    self.evaluator,
                    self.metrics,
                )
                self.metrics = self.evaluator.update(
                    "loss", loss.item() / scale, self.metrics
                )

                # Print metrics, make plots.
                log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
                log_dict.update(
                    {"epoch": epoch + (i + 1) / len(self.train_loader)}
                )
                if i % self.config["cmd"]["print_every"] == 0:
                    log_str = [
                        "{}: {:.4f}".format(k, v) for k, v in log_dict.items()
                    ]
                    print(", ".join(log_str))
                    self.metrics = {}

                if self.logger is not None:
                    self.logger.log(
                        log_dict,
                        step=epoch * len(self.train_loader) + i + 1,
                        split="train",
                    )

                iters += 1

                # Evaluate on val set every `eval_every` iterations.
                if eval_every != -1 and iters % eval_every == 0:
                    if self.val_loader is not None:
                        val_metrics = self.validate(split="val", epoch=epoch)
                        if (
                            val_metrics[
                                self.evaluator.task_primary_metric["s2ef"]
                            ]["metric"]
                            < self.best_val_mae
                        ):
                            self.best_val_mae = val_metrics[
                                self.evaluator.task_primary_metric["s2ef"]
                            ]["metric"]
                            if not self.is_debug and distutils.is_master():
                                save_checkpoint(
                                    {
                                        "epoch": epoch
                                        + (i + 1) / len(self.train_loader),
                                        "state_dict": self.model.state_dict(),
                                        "optimizer": self.optimizer.state_dict(),
                                        "normalizers": {
                                            key: value.state_dict()
                                            for key, value in self.normalizers.items()
                                        },
                                        "config": self.config,
                                        "val_metrics": val_metrics,
                                    },
                                    self.config["cmd"]["checkpoint_dir"],
                                )

            self.scheduler.step()
            torch.cuda.empty_cache()

            if eval_every == -1:
                if self.val_loader is not None:
                    val_metrics = self.validate(split="val", epoch=epoch)
                    if (
                        val_metrics[
                            self.evaluator.task_primary_metric["s2ef"]
                        ]["metric"]
                        < self.best_val_mae
                    ):
                        self.best_val_mae = val_metrics[
                            self.evaluator.task_primary_metric["s2ef"]
                        ]["metric"]
                        if not self.is_debug and distutils.is_master():
                            save_checkpoint(
                                {
                                    "epoch": epoch + 1,
                                    "state_dict": self.model.state_dict(),
                                    "optimizer": self.optimizer.state_dict(),
                                    "normalizers": {
                                        key: value.state_dict()
                                        for key, value in self.normalizers.items()
                                    },
                                    "config": self.config,
                                    "val_metrics": val_metrics,
                                },
                                self.config["cmd"]["checkpoint_dir"],
                            )

                if self.test_loader is not None:
                    self.validate(split="test", epoch=epoch)

            if (
                "relaxation_dir" in self.config["task"]
                and self.config["task"].get("ml_relax", "end") == "train"
            ):
                self.validate_relaxation(
                    split="val",
                    epoch=epoch,
                )

        if (
            "relaxation_dir" in self.config["task"]
            and self.config["task"].get("ml_relax", "end") == "end"
        ):
            self.validate_relaxation(
                split="val",
                epoch=epoch,
            )

    def validate(self, split="val", epoch=None):
        print("### Evaluating on {}.".format(split))

        self.model.eval()
        evaluator, metrics = Evaluator(task="s2ef"), {}

        loader = self.val_loader if split == "val" else self.test_loader

        for i, batch in tqdm(enumerate(loader), total=len(loader)):
            # Forward.
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                out = self._forward(batch)
                loss = self._compute_loss(out, batch)

            # Compute metrics.
            metrics = self._compute_metrics(out, batch, evaluator, metrics)
            metrics = evaluator.update("loss", loss.item(), metrics)

        aggregated_metrics = {}
        for k in metrics:
            aggregated_metrics[k] = {
                "total": distutils.all_reduce(
                    metrics[k]["total"], average=False, device=self.device
                ),
                "numel": distutils.all_reduce(
                    metrics[k]["numel"], average=False, device=self.device
                ),
            }
            aggregated_metrics[k]["metric"] = (
                aggregated_metrics[k]["total"]
                / aggregated_metrics[k]["numel"]
            )
        metrics = aggregated_metrics

        log_dict = {k: metrics[k]["metric"] for k in metrics}
        log_dict.update({"epoch": epoch + 1})
        if distutils.is_master():
            log_str = [
                "{}: {:.4f}".format(k, v) for k, v in log_dict.items()
            ]
            print(", ".join(log_str))

        # Make plots.
        if self.logger is not None and epoch is not None:
            self.logger.log(
                log_dict,
                step=(epoch + 1) * len(self.train_loader),
                split=split,
            )

        return metrics

    def validate_relaxation(self, split="val", epoch=None):
        print("### Evaluating ML-relaxation")
        self.model.eval()

        mae_energy, mae_structure = relax_eval(
            trainer=self,
            traj_dir=self.config["task"]["relaxation_dir"],
            metric=self.config["task"]["metric"],
            steps=self.config["task"].get("relaxation_steps", 300),
            fmax=self.config["task"].get("relaxation_fmax", 0.01),
            results_dir=self.config["cmd"]["results_dir"],
        )

        mae_energy = distutils.all_reduce(
            mae_energy, average=True, device=self.device
        )
        mae_structure = distutils.all_reduce(
            mae_structure, average=True, device=self.device
        )

        log_dict = {
            "relaxed_energy_mae": mae_energy,
            "relaxed_structure_mae": mae_structure,
            "epoch": epoch + 1,
        }

        # Make plots.
        if self.logger is not None and epoch is not None:
            self.logger.log(
                log_dict,
                step=(epoch + 1) * len(self.train_loader),
                split=split,
            )

        print(log_dict)
        return mae_energy, mae_structure

    def _forward(self, batch_list):
        # forward pass.
        if self.config["model_attributes"].get("regress_forces", True):
            out_energy, out_forces = self.model(batch_list)
        else:
            out_energy = self.model(batch_list)

        if out_energy.shape[-1] == 1:
            out_energy = out_energy.view(-1)

        out = {
            "energy": out_energy,
        }

        if self.config["model_attributes"].get("regress_forces", True):
            out["forces"] = out_forces

        return out

    def _compute_loss(self, out, batch_list):
        loss = []

        # Energy loss.
        energy_target = torch.cat(
            [batch.y.to(self.device) for batch in batch_list], dim=0
        )
        if self.config["dataset"].get("normalize_labels", True):
            energy_target = self.normalizers["target"].norm(energy_target)
        energy_mult = self.config["optim"].get("energy_coefficient", 1)
        loss.append(energy_mult * self.criterion(out["energy"], energy_target))

        # Force loss.
        if self.config["model_attributes"].get("regress_forces", True):
            force_target = torch.cat(
                [batch.force.to(self.device) for batch in batch_list], dim=0
            )
            if self.config["dataset"].get("normalize_labels", True):
                force_target = self.normalizers["grad_target"].norm(
                    force_target
                )

            # Force coefficient = 30 has been working well for us.
            force_mult = self.config["optim"].get("force_coefficient", 30)
            if self.config["task"].get("train_on_free_atoms", False):
                fixed = torch.cat(
                    [batch.fixed.to(self.device) for batch in batch_list]
                )
                mask = fixed == 0
                loss.append(
                    force_mult
                    * self.criterion(out["forces"][mask], force_target[mask])
                )
            else:
                loss.append(
                    force_mult * self.criterion(out["forces"], force_target)
                )

        # Sanity check to make sure the compute graph is correct.
        for lc in loss:
            assert hasattr(lc, "grad_fn")

        loss = sum(loss)
        return loss

    def _compute_metrics(self, out, batch_list, evaluator, metrics={}):
        target = {
            "energy": torch.cat(
                [batch.y.to(self.device) for batch in batch_list], dim=0
            ),
            "forces": torch.cat(
                [batch.force.to(self.device) for batch in batch_list], dim=0
            ),
        }

        if self.config["task"].get("eval_on_free_atoms", True):
            fixed = torch.cat(
                [batch.fixed.to(self.device) for batch in batch_list]
            )
            mask = fixed == 0
            out["forces"] = out["forces"][mask]
            target["forces"] = target["forces"][mask]

        if self.config["dataset"].get("normalize_labels", True):
            out["energy"] = self.normalizers["target"].denorm(out["energy"])
            out["forces"] = self.normalizers["grad_target"].denorm(
                out["forces"]
            )

        metrics = evaluator.eval(out, target, prev_metrics=metrics)
        return metrics
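# Illustrative sketch (not from the source) of the free-atom masking used in
# _compute_loss/_compute_metrics above: "fixed" flags constrained atoms, and
# only rows where fixed == 0 contribute to the force loss. F.l1_loss stands in
# for self.criterion here, which is an assumption.
import torch
import torch.nn.functional as F

fixed = torch.tensor([1, 0, 0, 1, 0])  # 5 atoms; atoms 1, 2, and 4 are free
forces_pred = torch.randn(5, 3)
forces_target = torch.randn(5, 3)

mask = fixed == 0  # boolean mask selecting free atoms
loss = F.l1_loss(forces_pred[mask], forces_target[mask])  # loss over the 3 free atoms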
class EnergyTrainer(BaseTrainer):
    """
    Trainer class for the Initial Structure to Relaxed Energy (IS2RE) task.

    .. note::

        Examples of configurations for task, model, dataset and optimizer
        can be found in `configs/ocp_is2re <https://github.com/Open-Catalyst-Project/baselines/tree/master/configs/ocp_is2re/>`_.

    Args:
        task (dict): Task configuration.
        model (dict): Model configuration.
        dataset (dict): Dataset configuration. The dataset needs to be a SinglePointLMDB dataset.
        optimizer (dict): Optimizer configuration.
        identifier (str): Experiment identifier that is appended to log directory.
        run_dir (str, optional): Path to the run directory where logs are to be saved.
            (default: :obj:`None`)
        is_debug (bool, optional): Run in debug mode.
            (default: :obj:`False`)
        is_vis (bool, optional): Run in visualization mode.
            (default: :obj:`False`)
        print_every (int, optional): Frequency of printing logs.
            (default: :obj:`100`)
        seed (int, optional): Random number seed.
            (default: :obj:`None`)
        logger (str, optional): Type of logger to be used.
            (default: :obj:`tensorboard`)
        local_rank (int, optional): Local rank of the process, only applicable for distributed training.
            (default: :obj:`0`)
        amp (bool, optional): Run using automatic mixed precision.
            (default: :obj:`False`)
    """

    def __init__(
        self,
        task,
        model,
        dataset,
        optimizer,
        identifier,
        run_dir=None,
        is_debug=False,
        is_vis=False,
        print_every=100,
        seed=None,
        logger="tensorboard",
        local_rank=0,
        amp=False,
    ):
        if run_dir is None:
            run_dir = os.getcwd()

        timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        if identifier:
            timestamp += "-{}".format(identifier)

        self.config = {
            "task": task,
            "model": model.pop("name"),
            "model_attributes": model,
            "optim": optimizer,
            "logger": logger,
            "cmd": {
                "identifier": identifier,
                "print_every": print_every,
                "seed": seed,
                "timestamp": timestamp,
                "checkpoint_dir": os.path.join(run_dir, "checkpoints", timestamp),
                "results_dir": os.path.join(run_dir, "results", timestamp),
                "logs_dir": os.path.join(run_dir, "logs", logger, timestamp),
            },
            "amp": amp,
        }

        # AMP Scaler
        self.scaler = torch.cuda.amp.GradScaler() if amp else None

        if isinstance(dataset, list):
            self.config["dataset"] = dataset[0]
            if len(dataset) > 1:
                self.config["val_dataset"] = dataset[1]
            if len(dataset) > 2:
                self.config["test_dataset"] = dataset[2]
        else:
            self.config["dataset"] = dataset

        if not is_debug and distutils.is_master():
            os.makedirs(self.config["cmd"]["checkpoint_dir"])
            os.makedirs(self.config["cmd"]["results_dir"])
            os.makedirs(self.config["cmd"]["logs_dir"])

        self.is_debug = is_debug
        self.is_vis = is_vis

        if torch.cuda.is_available():
            self.device = local_rank
        else:
            self.device = "cpu"

        if distutils.is_master():
            print(yaml.dump(self.config, default_flow_style=False))
        self.load()

        self.evaluator = Evaluator(task="is2re")

    def load_task(self):
        print("### Loading dataset: {}".format(self.config["task"]["dataset"]))

        self.parallel_collater = ParallelCollater(
            1, self.config["model_attributes"].get("otf_graph", False)
        )
        if self.config["task"]["dataset"] == "single_point_lmdb":
            self.train_dataset = registry.get_dataset_class(
                self.config["task"]["dataset"]
            )(self.config["dataset"])

            self.train_sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=distutils.get_world_size(),
                rank=distutils.get_rank(),
                shuffle=True,
            )
            self.train_loader = DataLoader(
                self.train_dataset,
                batch_size=self.config["optim"]["batch_size"],
                collate_fn=self.parallel_collater,
                num_workers=self.config["optim"]["num_workers"],
                pin_memory=True,
                sampler=self.train_sampler,
            )

            self.val_loader = self.test_loader = None
            self.val_sampler = None

            if "val_dataset" in self.config:
                self.val_dataset = registry.get_dataset_class(
                    self.config["task"]["dataset"]
                )(self.config["val_dataset"])
                self.val_sampler = DistributedSampler(
                    self.val_dataset,
                    num_replicas=distutils.get_world_size(),
                    rank=distutils.get_rank(),
                    shuffle=False,
                )
                self.val_loader = DataLoader(
                    self.val_dataset,
                    self.config["optim"].get("eval_batch_size", 64),
                    collate_fn=self.parallel_collater,
                    num_workers=self.config["optim"]["num_workers"],
                    pin_memory=True,
                    sampler=self.val_sampler,
                )
            if "test_dataset" in self.config:
                self.test_dataset = registry.get_dataset_class(
                    self.config["task"]["dataset"]
                )(self.config["test_dataset"])
                self.test_sampler = DistributedSampler(
                    self.test_dataset,
                    num_replicas=distutils.get_world_size(),
                    rank=distutils.get_rank(),
                    shuffle=False,
                )
                self.test_loader = DataLoader(
                    self.test_dataset,
                    self.config["optim"].get("eval_batch_size", 64),
                    collate_fn=self.parallel_collater,
                    num_workers=self.config["optim"]["num_workers"],
                    pin_memory=True,
                    sampler=self.test_sampler,
                )
        else:
            raise NotImplementedError

        self.num_targets = 1

        # Normalizer for the dataset.
        # Compute mean, std of training set labels.
        self.normalizers = {}
        if self.config["dataset"].get("normalize_labels", False):
            if "target_mean" in self.config["dataset"]:
                self.normalizers["target"] = Normalizer(
                    mean=self.config["dataset"]["target_mean"],
                    std=self.config["dataset"]["target_std"],
                    device=self.device,
                )
            else:
                raise NotImplementedError

    def load_model(self):
        super(EnergyTrainer, self).load_model()

        self.model = OCPDataParallel(
            self.model,
            output_device=self.device,
            num_gpus=self.config["optim"].get("num_gpus", 1),
        )
        if distutils.initialized():
            self.model = DistributedDataParallel(
                self.model, device_ids=[self.device]
            )

    def train(self):
        self.best_val_mae = 1e9
        for epoch in range(self.config["optim"]["max_epochs"]):
            self.train_sampler.set_epoch(epoch)
            self.model.train()
            for i, batch in enumerate(self.train_loader):
                # Forward, loss, backward.
                with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                    out = self._forward(batch)
                    loss = self._compute_loss(out, batch)
                loss = self.scaler.scale(loss) if self.scaler else loss
                self._backward(loss)
                scale = self.scaler.get_scale() if self.scaler else 1.0

                # Compute metrics.
                self.metrics = self._compute_metrics(
                    out,
                    batch,
                    self.evaluator,
                    metrics={},
                )
                self.metrics = self.evaluator.update(
                    "loss", loss.item() / scale, self.metrics
                )

                # Print metrics, make plots.
                log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
                log_dict.update(
                    {"epoch": epoch + (i + 1) / len(self.train_loader)}
                )
                if i % self.config["cmd"]["print_every"] == 0:
                    log_str = [
                        "{}: {:.4f}".format(k, v) for k, v in log_dict.items()
                    ]
                    print(", ".join(log_str))

                if self.logger is not None:
                    self.logger.log(
                        log_dict,
                        step=epoch * len(self.train_loader) + i + 1,
                        split="train",
                    )

            self.scheduler.step()
            torch.cuda.empty_cache()

            if self.val_loader is not None:
                val_metrics = self.validate(split="val", epoch=epoch)
                if (
                    val_metrics[self.evaluator.task_primary_metric["is2re"]][
                        "metric"
                    ]
                    < self.best_val_mae
                ):
                    self.best_val_mae = val_metrics[
                        self.evaluator.task_primary_metric["is2re"]
                    ]["metric"]
                    if not self.is_debug and distutils.is_master():
                        save_checkpoint(
                            {
                                "epoch": epoch + 1,
                                "state_dict": self.model.state_dict(),
                                "optimizer": self.optimizer.state_dict(),
                                "normalizers": {
                                    key: value.state_dict()
                                    for key, value in self.normalizers.items()
                                },
                                "config": self.config,
                                "val_metrics": val_metrics,
                                "amp": self.scaler.state_dict()
                                if self.scaler
                                else None,
                            },
                            self.config["cmd"]["checkpoint_dir"],
                        )
            else:
                if not self.is_debug and distutils.is_master():
                    save_checkpoint(
                        {
                            "epoch": epoch + 1,
                            "state_dict": self.model.state_dict(),
                            "optimizer": self.optimizer.state_dict(),
                            "normalizers": {
                                key: value.state_dict()
                                for key, value in self.normalizers.items()
                            },
                            "config": self.config,
                            "metrics": self.metrics,
                            "amp": self.scaler.state_dict()
                            if self.scaler
                            else None,
                        },
                        self.config["cmd"]["checkpoint_dir"],
                    )

            if self.test_loader is not None:
                self.validate(split="test", epoch=epoch)

    def validate(self, split="val", epoch=None):
        print("### Evaluating on {}.".format(split))

        self.model.eval()
        evaluator, metrics = Evaluator(task="is2re"), {}

        loader = self.val_loader if split == "val" else self.test_loader

        for i, batch in tqdm(enumerate(loader), total=len(loader)):
            # Forward.
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                out = self._forward(batch)
                loss = self._compute_loss(out, batch)

            # Compute metrics.
            metrics = self._compute_metrics(out, batch, evaluator, metrics)
            metrics = evaluator.update("loss", loss.item(), metrics)

        aggregated_metrics = {}
        for k in metrics:
            aggregated_metrics[k] = {
                "total": distutils.all_reduce(
                    metrics[k]["total"], average=False, device=self.device
                ),
                "numel": distutils.all_reduce(
                    metrics[k]["numel"], average=False, device=self.device
                ),
            }
            aggregated_metrics[k]["metric"] = (
                aggregated_metrics[k]["total"]
                / aggregated_metrics[k]["numel"]
            )
        metrics = aggregated_metrics

        log_dict = {k: metrics[k]["metric"] for k in metrics}
        log_dict.update({"epoch": epoch + 1})
        log_str = ["{}: {:.4f}".format(k, v) for k, v in log_dict.items()]
        print(", ".join(log_str))

        # Make plots.
        if self.logger is not None and epoch is not None:
            self.logger.log(
                log_dict,
                step=(epoch + 1) * len(self.train_loader),
                split=split,
            )

        return metrics

    def _forward(self, batch_list):
        output = self.model(batch_list)

        if output.shape[-1] == 1:
            output = output.view(-1)

        return {
            "energy": output,
        }

    def _compute_loss(self, out, batch_list):
        energy_target = torch.cat(
            [batch.y_relaxed.to(self.device) for batch in batch_list], dim=0
        )

        if self.config["dataset"].get("normalize_labels", False):
            target_normed = self.normalizers["target"].norm(energy_target)
        else:
            target_normed = energy_target

        loss = self.criterion(out["energy"], target_normed)
        return loss

    def _compute_metrics(self, out, batch_list, evaluator, metrics={}):
        energy_target = torch.cat(
            [batch.y_relaxed.to(self.device) for batch in batch_list], dim=0
        )

        if self.config["dataset"].get("normalize_labels", False):
            out["energy"] = self.normalizers["target"].denorm(out["energy"])

        metrics = evaluator.eval(
            out,
            {"energy": energy_target},
            prev_metrics=metrics,
        )
        return metrics

    def predict(self, loader, results_file=None, disable_tqdm=False):
        assert isinstance(loader, torch.utils.data.dataloader.DataLoader)

        self.model.eval()
        if self.normalizers is not None and "target" in self.normalizers:
            self.normalizers["target"].to(self.device)
        predictions = []

        for i, batch in tqdm(
            enumerate(loader), total=len(loader), disable=disable_tqdm
        ):
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                out = self._forward(batch)

            if self.normalizers is not None and "target" in self.normalizers:
                out["energy"] = self.normalizers["target"].denorm(
                    out["energy"]
                )
            predictions.extend(out["energy"].tolist())

        if results_file is not None:
            print(f"Writing results to {results_file}")
            # EvalAI expects a list of energies
            with open(results_file, "w") as resfile:
                json.dump(predictions, resfile)

        return predictions