def validate(self, split="val", epoch=None, disable_tqdm=False):
    if distutils.is_master():
        print("### Evaluating on {}.".format(split))
    if self.is_hpo:
        disable_tqdm = True

    self.model.eval()
    evaluator, metrics = Evaluator(task=self.name), {}
    rank = distutils.get_rank()

    loader = self.val_loader if split == "val" else self.test_loader
    for i, batch in tqdm(
        enumerate(loader),
        total=len(loader),
        position=rank,
        desc="device {}".format(rank),
        disable=disable_tqdm,
    ):
        # Forward.
        with torch.cuda.amp.autocast(enabled=self.scaler is not None):
            out = self._forward(batch)
        loss = self._compute_loss(out, batch)

        # Compute metrics.
        metrics = self._compute_metrics(out, batch, evaluator, metrics)
        metrics = evaluator.update("loss", loss.item(), metrics)

    # Aggregate running sums and counts across distributed workers.
    aggregated_metrics = {}
    for k in metrics:
        aggregated_metrics[k] = {
            "total": distutils.all_reduce(
                metrics[k]["total"], average=False, device=self.device
            ),
            "numel": distutils.all_reduce(
                metrics[k]["numel"], average=False, device=self.device
            ),
        }
        aggregated_metrics[k]["metric"] = (
            aggregated_metrics[k]["total"] / aggregated_metrics[k]["numel"]
        )
    metrics = aggregated_metrics

    log_dict = {k: metrics[k]["metric"] for k in metrics}
    if epoch is not None:
        log_dict.update({"epoch": epoch + 1})
    if distutils.is_master():
        log_str = ["{}: {:.4f}".format(k, v) for k, v in log_dict.items()]
        print(", ".join(log_str))

    # Make plots.
    if self.logger is not None and epoch is not None:
        self.logger.log(
            log_dict,
            step=(epoch + 1) * len(self.train_loader),
            split=split,
        )

    return metrics
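# Minimal single-process illustration (made-up numbers, hypothetical metric key)
# of the {"total", "numel", "metric"} bookkeeping used in validate() above.
# Each metric carries a running sum and sample count rather than an average so
# that, in the distributed run, distutils.all_reduce can sum both fields across
# ranks before the final division.
metrics = {}
for batch_abs_err_sum, batch_count in [(12.0, 8), (5.0, 4)]:
    entry = metrics.setdefault("energy_mae", {"total": 0.0, "numel": 0})
    entry["total"] += batch_abs_err_sum  # running sum of absolute errors
    entry["numel"] += batch_count        # running sample count

for entry in metrics.values():
    entry["metric"] = entry["total"] / entry["numel"]

print(metrics["energy_mae"]["metric"])  # 17.0 / 12 ≈ 1.4167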
def load_model(self):
    # Build model
    if distutils.is_master():
        logging.info(f"Loading model: {self.config['model']}")

    # TODO(abhshkdz): Eventually move towards computing features on-the-fly
    # and remove dependence from `.edge_attr`.
    bond_feat_dim = None
    if self.config["task"]["dataset"] in [
        "trajectory_lmdb",
        "single_point_lmdb",
    ]:
        bond_feat_dim = self.config["model_attributes"].get(
            "num_gaussians", 50
        )
    else:
        raise NotImplementedError

    loader = self.train_loader or self.val_loader or self.test_loader
    self.model = registry.get_model_class(self.config["model"])(
        loader.dataset[0].x.shape[-1]
        if loader
        and hasattr(loader.dataset[0], "x")
        and loader.dataset[0].x is not None
        else None,
        bond_feat_dim,
        self.num_targets,
        **self.config["model_attributes"],
    ).to(self.device)

    if distutils.is_master():
        logging.info(
            f"Loaded {self.model.__class__.__name__} with "
            f"{self.model.num_params} parameters."
        )

    if self.logger is not None:
        self.logger.watch(self.model)

    self.model = OCPDataParallel(
        self.model,
        output_device=self.device,
        num_gpus=1 if not self.cpu else 0,
    )
    if distutils.initialized():
        self.model = DistributedDataParallel(
            self.model, device_ids=[self.device]
        )
def load_logger(self):
    self.logger = None
    if not self.is_debug and distutils.is_master():
        assert (
            self.config["logger"] is not None
        ), "Specify logger in config"
        self.logger = registry.get_logger_class(self.config["logger"])(
            self.config
        )
def predict(self, loader, results_file=None, disable_tqdm=False):
    if distutils.is_master() and not disable_tqdm:
        print("### Predicting on test.")

    assert isinstance(loader, torch.utils.data.dataloader.DataLoader)
    rank = distutils.get_rank()

    self.model.eval()
    if self.normalizers is not None and "target" in self.normalizers:
        self.normalizers["target"].to(self.device)
    predictions = {"id": [], "energy": []}

    for i, batch in tqdm(
        enumerate(loader),
        total=len(loader),
        position=rank,
        desc="device {}".format(rank),
        disable=disable_tqdm,
    ):
        with torch.cuda.amp.autocast(enabled=self.scaler is not None):
            out = self._forward(batch)

        if self.normalizers is not None and "target" in self.normalizers:
            out["energy"] = self.normalizers["target"].denorm(out["energy"])

        predictions["id"].extend([str(i) for i in batch[0].sid.tolist()])
        predictions["energy"].extend(out["energy"].tolist())

    self.save_results(predictions, results_file, keys=["energy"])
    return predictions
def main(config): if args.distributed: distutils.setup(config) try: setup_imports() trainer = registry.get_trainer_class(config.get("trainer", "simple"))( task=config["task"], model=config["model"], dataset=config["dataset"], optimizer=config["optim"], identifier=config["identifier"], run_dir=config.get("run_dir", "./"), is_debug=config.get("is_debug", False), is_vis=config.get("is_vis", False), print_every=config.get("print_every", 10), seed=config.get("seed", 0), logger=config.get("logger", "tensorboard"), local_rank=config["local_rank"], amp=config.get("amp", False), cpu=config.get("cpu", False), ) if config["checkpoint"] is not None: trainer.load_pretrained(config["checkpoint"]) start_time = time.time() if config["mode"] == "train": trainer.train() elif config["mode"] == "predict": assert ( trainer.test_loader is not None), "Test dataset is required for making predictions" assert config["checkpoint"] results_file = "predictions" trainer.predict( trainer.test_loader, results_file=results_file, disable_tqdm=False, ) elif config["mode"] == "run-relaxations": assert isinstance( trainer, ForcesTrainer ), "Relaxations are only possible for ForcesTrainer" assert (trainer.relax_dataset is not None ), "Relax dataset is required for making predictions" assert config["checkpoint"] trainer.run_relaxations() distutils.synchronize() if distutils.is_master(): print("Total time taken = ", time.time() - start_time) finally: if args.distributed: distutils.cleanup()
def predict(self, loader, per_image=True, results_file=None, disable_tqdm=False): if distutils.is_master() and not disable_tqdm: logging.info("Predicting on test.") assert isinstance( loader, ( torch.utils.data.dataloader.DataLoader, torch_geometric.data.Batch, ), ) rank = distutils.get_rank() if isinstance(loader, torch_geometric.data.Batch): loader = [[loader]] self.model.eval() if self.ema: self.ema.store() self.ema.copy_to() if self.normalizers is not None and "target" in self.normalizers: self.normalizers["target"].to(self.device) predictions = {"id": [], "energy": []} for i, batch in tqdm( enumerate(loader), total=len(loader), position=rank, desc="device {}".format(rank), disable=disable_tqdm, ): with torch.cuda.amp.autocast(enabled=self.scaler is not None): out = self._forward(batch) if self.normalizers is not None and "target" in self.normalizers: out["energy"] = self.normalizers["target"].denorm( out["energy"]) if per_image: predictions["id"].extend( [str(i) for i in batch[0].sid.tolist()]) predictions["energy"].extend(out["energy"].tolist()) else: predictions["energy"] = out["energy"].detach() return predictions self.save_results(predictions, results_file, keys=["energy"]) if self.ema: self.ema.restore() return predictions
def save( self, metrics=None, checkpoint_file="checkpoint.pt", training_state=True, ): if not self.is_debug and distutils.is_master(): if training_state: save_checkpoint( { "epoch": self.epoch, "step": self.step, "state_dict": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), "scheduler": self.scheduler.scheduler.state_dict() if self.scheduler.scheduler_type != "Null" else None, "normalizers": { key: value.state_dict() for key, value in self.normalizers.items() }, "config": self.config, "val_metrics": metrics, "ema": self.ema.state_dict() if self.ema else None, "amp": self.scaler.state_dict() if self.scaler else None, }, checkpoint_dir=self.config["cmd"]["checkpoint_dir"], checkpoint_file=checkpoint_file, ) else: if self.ema: self.ema.store() self.ema.copy_to() save_checkpoint( { "state_dict": self.model.state_dict(), "normalizers": { key: value.state_dict() for key, value in self.normalizers.items() }, "config": self.config, "val_metrics": metrics, "amp": self.scaler.state_dict() if self.scaler else None, }, checkpoint_dir=self.config["cmd"]["checkpoint_dir"], checkpoint_file=checkpoint_file, ) if self.ema: self.ema.restore()
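# A rough sketch of restoring state from a checkpoint written by the `save`
# method above. The path, the `trainer` handle, and the assumption that the
# project's Normalizer exposes load_state_dict (its state_dict is used above)
# are illustrative; the project itself restores checkpoints through its own
# loading helpers rather than inline code like this.
import torch

checkpoint = torch.load("checkpoints/checkpoint.pt", map_location="cpu")

trainer.model.load_state_dict(checkpoint["state_dict"])
trainer.optimizer.load_state_dict(checkpoint["optimizer"])
if checkpoint.get("scheduler") is not None:
    trainer.scheduler.scheduler.load_state_dict(checkpoint["scheduler"])
for name, state in checkpoint.get("normalizers", {}).items():
    trainer.normalizers[name].load_state_dict(state)
if trainer.scaler is not None and checkpoint.get("amp") is not None:
    trainer.scaler.load_state_dict(checkpoint["amp"])
trainer.epoch, trainer.step = checkpoint["epoch"], checkpoint["step"]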
def save_results(self, predictions, results_file, keys): if results_file is None: return results_file_path = os.path.join( self.config["cmd"]["results_dir"], f"{self.name}_{results_file}_{distutils.get_rank()}.npz", ) np.savez_compressed( results_file_path, ids=predictions["id"], **{key: predictions[key] for key in keys}, ) distutils.synchronize() if distutils.is_master(): gather_results = defaultdict(list) full_path = os.path.join( self.config["cmd"]["results_dir"], f"{self.name}_{results_file}.npz", ) for i in range(distutils.get_world_size()): rank_path = os.path.join( self.config["cmd"]["results_dir"], f"{self.name}_{results_file}_{i}.npz", ) rank_results = np.load(rank_path, allow_pickle=True) gather_results["ids"].extend(rank_results["ids"]) for key in keys: gather_results[key].extend(rank_results[key]) os.remove(rank_path) # Because of how distributed sampler works, some system ids # might be repeated to make no. of samples even across GPUs. _, idx = np.unique(gather_results["ids"], return_index=True) gather_results["ids"] = np.array(gather_results["ids"])[idx] for k in keys: if k == "forces": gather_results[k] = np.concatenate( np.array(gather_results[k])[idx] ) elif k == "chunk_idx": gather_results[k] = np.cumsum( np.array(gather_results[k])[idx] )[:-1] else: gather_results[k] = np.array(gather_results[k])[idx] logging.info(f"Writing results to {full_path}") np.savez_compressed(full_path, **gather_results)
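# A small sketch of reading back a merged results file written by save_results()
# above. The file name is an assumption (results_dir / f"{name}_{results_file}.npz");
# "chunk_idx" is only present for tasks that also write per-atom arrays such as
# forces, in which case it holds the np.split boundaries between systems.
import numpy as np

results = np.load("results/s2ef_predictions.npz", allow_pickle=True)
ids = results["ids"]
energies = results["energy"]

if "forces" in results and "chunk_idx" in results:
    # Recover one (natoms_i, 3) force array per system id from the
    # concatenated forces array.
    per_system_forces = np.split(results["forces"], results["chunk_idx"])
    print(ids[0], energies[0], per_system_forces[0].shape)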
def save(self, epoch, metrics):
    if not self.is_debug and distutils.is_master():
        save_checkpoint(
            {
                "epoch": epoch,
                "state_dict": self.model.state_dict(),
                "optimizer": self.optimizer.state_dict(),
                "normalizers": {
                    key: value.state_dict()
                    for key, value in self.normalizers.items()
                },
                "config": self.config,
                "val_metrics": metrics,
                "amp": self.scaler.state_dict() if self.scaler else None,
            },
            self.config["cmd"]["checkpoint_dir"],
        )
def __call__(self, config): setup_logging() self.config = copy.deepcopy(config) if args.distributed: distutils.setup(config) try: setup_imports() self.trainer = registry.get_trainer_class( config.get("trainer", "simple"))( task=config["task"], model=config["model"], dataset=config["dataset"], optimizer=config["optim"], identifier=config["identifier"], timestamp_id=config.get("timestamp_id", None), run_dir=config.get("run_dir", "./"), is_debug=config.get("is_debug", False), is_vis=config.get("is_vis", False), print_every=config.get("print_every", 10), seed=config.get("seed", 0), logger=config.get("logger", "tensorboard"), local_rank=config["local_rank"], amp=config.get("amp", False), cpu=config.get("cpu", False), slurm=config.get("slurm", {}), ) self.task = registry.get_task_class(config["mode"])(self.config) self.task.setup(self.trainer) start_time = time.time() self.task.run() distutils.synchronize() if distutils.is_master(): logging.info(f"Total time taken: {time.time() - start_time}") finally: if args.distributed: distutils.cleanup()
def __init__( self, task, model, dataset, optimizer, identifier, run_dir=None, is_debug=False, is_vis=False, is_hpo=False, print_every=100, seed=None, logger="tensorboard", local_rank=0, amp=False, cpu=False, name="base_trainer", ): self.name = name self.cpu = cpu self.start_step = 0 if torch.cuda.is_available() and not self.cpu: self.device = local_rank else: self.device = "cpu" self.cpu = True # handle case when `--cpu` isn't specified # but there are no gpu devices available if run_dir is None: run_dir = os.getcwd() timestamp = torch.tensor(datetime.datetime.now().timestamp()).to( self.device) # create directories from master rank only distutils.broadcast(timestamp, 0) timestamp = datetime.datetime.fromtimestamp( timestamp.int()).strftime("%Y-%m-%d-%H-%M-%S") if identifier: timestamp += "-{}".format(identifier) try: commit_hash = (subprocess.check_output([ "git", "-C", ocpmodels.__path__[0], "describe", "--always", ]).strip().decode("ascii")) # catch instances where code is not being run from a git repo except Exception: commit_hash = None self.config = { "task": task, "model": model.pop("name"), "model_attributes": model, "optim": optimizer, "logger": logger, "amp": amp, "gpus": distutils.get_world_size() if not self.cpu else 0, "cmd": { "identifier": identifier, "print_every": print_every, "seed": seed, "timestamp": timestamp, "commit": commit_hash, "checkpoint_dir": os.path.join(run_dir, "checkpoints", timestamp), "results_dir": os.path.join(run_dir, "results", timestamp), "logs_dir": os.path.join(run_dir, "logs", logger, timestamp), }, } # AMP Scaler self.scaler = torch.cuda.amp.GradScaler() if amp else None if isinstance(dataset, list): self.config["dataset"] = dataset[0] if len(dataset) > 1: self.config["val_dataset"] = dataset[1] if len(dataset) > 2: self.config["test_dataset"] = dataset[2] else: self.config["dataset"] = dataset if not is_debug and distutils.is_master() and not is_hpo: os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True) os.makedirs(self.config["cmd"]["results_dir"], exist_ok=True) os.makedirs(self.config["cmd"]["logs_dir"], exist_ok=True) self.is_debug = is_debug self.is_vis = is_vis self.is_hpo = is_hpo if self.is_hpo: # sets the hpo checkpoint frequency # default is no checkpointing self.hpo_checkpoint_every = self.config["optim"].get( "checkpoint_every", -1) if distutils.is_master(): print(yaml.dump(self.config, default_flow_style=False)) self.load() self.evaluator = Evaluator(task=name)
def train(self): self.best_val_mae = 1e9 eval_every = self.config["optim"].get("eval_every", -1) iters = 0 self.metrics = {} for epoch in range(self.config["optim"]["max_epochs"]): self.model.train() for i, batch in enumerate(self.train_loader): # Forward, loss, backward. with torch.cuda.amp.autocast(enabled=self.scaler is not None): out = self._forward(batch) loss = self._compute_loss(out, batch) loss = self.scaler.scale(loss) if self.scaler else loss self._backward(loss) scale = self.scaler.get_scale() if self.scaler else 1.0 # Compute metrics. self.metrics = self._compute_metrics( out, batch, self.evaluator, self.metrics, ) self.metrics = self.evaluator.update("loss", loss.item() / scale, self.metrics) # Print metrics, make plots. log_dict = {k: self.metrics[k]["metric"] for k in self.metrics} log_dict.update( {"epoch": epoch + (i + 1) / len(self.train_loader)}) if i % self.config["cmd"]["print_every"] == 0: log_str = [ "{}: {:.4f}".format(k, v) for k, v in log_dict.items() ] print(", ".join(log_str)) self.metrics = {} if self.logger is not None: self.logger.log( log_dict, step=epoch * len(self.train_loader) + i + 1, split="train", ) iters += 1 # Evaluate on val set every `eval_every` iterations. if eval_every != -1 and iters % eval_every == 0: if self.val_loader is not None: val_metrics = self.validate(split="val", epoch=epoch) if (val_metrics[self.evaluator. task_primary_metric["s2ef"]]["metric"] < self.best_val_mae): self.best_val_mae = val_metrics[ self.evaluator. task_primary_metric["s2ef"]]["metric"] if not self.is_debug and distutils.is_master(): save_checkpoint( { "epoch": epoch + (i + 1) / len(self.train_loader), "state_dict": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), "normalizers": { key: value.state_dict() for key, value in self.normalizers.items() }, "config": self.config, "val_metrics": val_metrics, }, self.config["cmd"]["checkpoint_dir"], ) self.scheduler.step() torch.cuda.empty_cache() if eval_every == -1: if self.val_loader is not None: val_metrics = self.validate(split="val", epoch=epoch) if (val_metrics[self.evaluator.task_primary_metric["s2ef"]] ["metric"] < self.best_val_mae): self.best_val_mae = val_metrics[ self.evaluator. task_primary_metric["s2ef"]]["metric"] if not self.is_debug and distutils.is_master(): save_checkpoint( { "epoch": epoch + 1, "state_dict": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), "normalizers": { key: value.state_dict() for key, value in self.normalizers.items() }, "config": self.config, "val_metrics": val_metrics, }, self.config["cmd"]["checkpoint_dir"], ) if self.test_loader is not None: self.validate(split="test", epoch=epoch) if ("relaxation_dir" in self.config["task"] and self.config["task"].get("ml_relax", "end") == "train"): self.validate_relaxation( split="val", epoch=epoch, ) if ("relaxation_dir" in self.config["task"] and self.config["task"].get("ml_relax", "end") == "end"): self.validate_relaxation( split="val", epoch=epoch, )
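# Made-up numbers illustrating why the training loops above log
# loss.item() / scale when AMP is active: scaler.scale() multiplies the loss by
# the current scale factor before backward(), so the raw .item() is inflated by
# that same factor.
true_loss = 0.25
scale = 65536.0                        # a typical GradScaler scale factor
scaled_loss_item = true_loss * scale   # what loss.item() reports after scaling
print(scaled_loss_item / scale)        # 0.25 -- the value that gets logged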
def run_relaxations(self, split="val", epoch=None): print("### Running ML-relaxations") self.model.eval() evaluator, metrics = Evaluator(task="is2rs"), {} if hasattr(self.relax_dataset[0], "pos_relaxed") and hasattr( self.relax_dataset[0], "y_relaxed"): split = "val" else: split = "test" ids = [] relaxed_positions = [] for i, batch in tqdm(enumerate(self.relax_loader), total=len(self.relax_loader)): relaxed_batch = ml_relax( batch=batch, model=self, steps=self.config["task"].get("relaxation_steps", 200), fmax=self.config["task"].get("relaxation_fmax", 0.0), relax_opt=self.config["task"]["relax_opt"], device=self.device, transform=None, ) if self.config["task"].get("write_pos", False): systemids = [str(i) for i in relaxed_batch.sid.tolist()] natoms = relaxed_batch.natoms.tolist() positions = torch.split(relaxed_batch.pos, natoms) batch_relaxed_positions = [pos.tolist() for pos in positions] relaxed_positions += batch_relaxed_positions ids += systemids if split == "val": mask = relaxed_batch.fixed == 0 s_idx = 0 natoms_free = [] for natoms in relaxed_batch.natoms: natoms_free.append( torch.sum(mask[s_idx:s_idx + natoms]).item()) s_idx += natoms target = { "energy": relaxed_batch.y_relaxed, "positions": relaxed_batch.pos_relaxed[mask], "cell": relaxed_batch.cell, "pbc": torch.tensor([True, True, True]), "natoms": torch.LongTensor(natoms_free), } prediction = { "energy": relaxed_batch.y, "positions": relaxed_batch.pos[mask], "cell": relaxed_batch.cell, "pbc": torch.tensor([True, True, True]), "natoms": torch.LongTensor(natoms_free), } metrics = evaluator.eval(prediction, target, metrics) if self.config["task"].get("write_pos", False): rank = distutils.get_rank() pos_filename = os.path.join(self.config["cmd"]["results_dir"], f"relaxed_pos_{rank}.npz") np.savez_compressed( pos_filename, ids=ids, pos=np.array(relaxed_positions, dtype=object), ) distutils.synchronize() if distutils.is_master(): gather_results = defaultdict(list) full_path = os.path.join( self.config["cmd"]["results_dir"], "relaxed_positions.npz", ) for i in range(distutils.get_world_size()): rank_path = os.path.join( self.config["cmd"]["results_dir"], f"relaxed_pos_{i}.npz", ) rank_results = np.load(rank_path, allow_pickle=True) gather_results["ids"].extend(rank_results["ids"]) gather_results["pos"].extend(rank_results["pos"]) os.remove(rank_path) # Because of how distributed sampler works, some system ids # might be repeated to make no. of samples even across GPUs. _, idx = np.unique(gather_results["ids"], return_index=True) gather_results["ids"] = np.array(gather_results["ids"])[idx] gather_results["pos"] = np.array(gather_results["pos"], dtype=object)[idx] print(f"Writing results to {full_path}") np.savez_compressed(full_path, **gather_results) if split == "val": aggregated_metrics = {} for k in metrics: aggregated_metrics[k] = { "total": distutils.all_reduce(metrics[k]["total"], average=False, device=self.device), "numel": distutils.all_reduce(metrics[k]["numel"], average=False, device=self.device), } aggregated_metrics[k]["metric"] = ( aggregated_metrics[k]["total"] / aggregated_metrics[k]["numel"]) metrics = aggregated_metrics # Make plots. log_dict = {k: metrics[k]["metric"] for k in metrics} if self.logger is not None and epoch is not None: self.logger.log( log_dict, step=(epoch + 1) * len(self.train_loader), split=split, ) if distutils.is_master(): print(metrics)
def predict(self, data_loader, per_image=True, results_file=None, disable_tqdm=True): if distutils.is_master() and not disable_tqdm: print("### Predicting on test.") assert isinstance( data_loader, ( torch.utils.data.dataloader.DataLoader, torch_geometric.data.Batch, ), ) rank = distutils.get_rank() if isinstance(data_loader, torch_geometric.data.Batch): data_loader = [[data_loader]] self.model.eval() if self.normalizers is not None and "target" in self.normalizers: self.normalizers["target"].to(self.device) self.normalizers["grad_target"].to(self.device) predictions = {"id": [], "energy": [], "forces": []} for i, batch_list in tqdm( enumerate(data_loader), total=len(data_loader), position=rank, desc="device {}".format(rank), disable=disable_tqdm, ): with torch.cuda.amp.autocast(enabled=self.scaler is not None): out = self._forward(batch_list) if self.normalizers is not None and "target" in self.normalizers: out["energy"] = self.normalizers["target"].denorm( out["energy"]) out["forces"] = self.normalizers["grad_target"].denorm( out["forces"]) if per_image: atoms_sum = 0 systemids = [ str(i) + "_" + str(j) for i, j in zip( batch_list[0].sid.tolist(), batch_list[0].fid.tolist()) ] predictions["id"].extend(systemids) predictions["energy"].extend(out["energy"].to( torch.float16).tolist()) batch_natoms = torch.cat( [batch.natoms for batch in batch_list]) batch_fixed = torch.cat([batch.fixed for batch in batch_list]) for natoms in batch_natoms: forces = (out["forces"][atoms_sum:natoms + atoms_sum].cpu().detach().to( torch.float16).numpy()) # evalAI only requires forces on free atoms if results_file is not None: _free_atoms = (batch_fixed[atoms_sum:natoms + atoms_sum] == 0).tolist() forces = forces[_free_atoms] atoms_sum += natoms predictions["forces"].append(forces) else: predictions["energy"] = out["energy"].detach() predictions["forces"] = out["forces"].detach() return predictions predictions["forces"] = np.array(predictions["forces"], dtype=object) predictions["energy"] = np.array(predictions["energy"]) predictions["id"] = np.array(predictions["id"]) self.save_results(predictions, results_file, keys=["energy", "forces"]) return predictions
def __init__( self, task, model, dataset, optimizer, identifier, normalizer=None, timestamp_id=None, run_dir=None, is_debug=False, is_vis=False, is_hpo=False, print_every=100, seed=None, logger="tensorboard", local_rank=0, amp=False, cpu=False, name="base_trainer", slurm={}, ): self.name = name self.cpu = cpu self.epoch = 0 self.step = 0 if torch.cuda.is_available() and not self.cpu: self.device = torch.device(f"cuda:{local_rank}") else: self.device = torch.device("cpu") self.cpu = True # handle case when `--cpu` isn't specified # but there are no gpu devices available if run_dir is None: run_dir = os.getcwd() if timestamp_id is None: timestamp = torch.tensor(datetime.datetime.now().timestamp()).to( self.device ) # create directories from master rank only distutils.broadcast(timestamp, 0) timestamp = datetime.datetime.fromtimestamp( timestamp.int() ).strftime("%Y-%m-%d-%H-%M-%S") if identifier: self.timestamp_id = f"{timestamp}-{identifier}" else: self.timestamp_id = timestamp else: self.timestamp_id = timestamp_id try: commit_hash = ( subprocess.check_output( [ "git", "-C", ocpmodels.__path__[0], "describe", "--always", ] ) .strip() .decode("ascii") ) # catch instances where code is not being run from a git repo except Exception: commit_hash = None self.config = { "task": task, "model": model.pop("name"), "model_attributes": model, "optim": optimizer, "logger": logger, "amp": amp, "gpus": distutils.get_world_size() if not self.cpu else 0, "cmd": { "identifier": identifier, "print_every": print_every, "seed": seed, "timestamp_id": self.timestamp_id, "commit": commit_hash, "checkpoint_dir": os.path.join( run_dir, "checkpoints", self.timestamp_id ), "results_dir": os.path.join( run_dir, "results", self.timestamp_id ), "logs_dir": os.path.join( run_dir, "logs", logger, self.timestamp_id ), }, "slurm": slurm, } # AMP Scaler self.scaler = torch.cuda.amp.GradScaler() if amp else None if "SLURM_JOB_ID" in os.environ and "folder" in self.config["slurm"]: self.config["slurm"]["job_id"] = os.environ["SLURM_JOB_ID"] self.config["slurm"]["folder"] = self.config["slurm"][ "folder" ].replace("%j", self.config["slurm"]["job_id"]) if isinstance(dataset, list): if len(dataset) > 0: self.config["dataset"] = dataset[0] if len(dataset) > 1: self.config["val_dataset"] = dataset[1] if len(dataset) > 2: self.config["test_dataset"] = dataset[2] elif isinstance(dataset, dict): self.config["dataset"] = dataset.get("train", None) self.config["val_dataset"] = dataset.get("val", None) self.config["test_dataset"] = dataset.get("test", None) else: self.config["dataset"] = dataset self.normalizer = normalizer # This supports the legacy way of providing norm parameters in dataset if self.config.get("dataset", None) is not None and normalizer is None: self.normalizer = self.config["dataset"] if not is_debug and distutils.is_master() and not is_hpo: os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True) os.makedirs(self.config["cmd"]["results_dir"], exist_ok=True) os.makedirs(self.config["cmd"]["logs_dir"], exist_ok=True) self.is_debug = is_debug self.is_vis = is_vis self.is_hpo = is_hpo if self.is_hpo: # conditional import is necessary for checkpointing from ray import tune from ocpmodels.common.hpo_utils import tune_reporter # sets the hpo checkpoint frequency # default is no checkpointing self.hpo_checkpoint_every = self.config["optim"].get( "checkpoint_every", -1 ) if distutils.is_master(): print(yaml.dump(self.config, default_flow_style=False)) self.load() self.evaluator = Evaluator(task=name)
def train(self, disable_eval_tqdm=False): eval_every = self.config["optim"].get("eval_every", None) if eval_every is None: eval_every = len(self.train_loader) checkpoint_every = self.config["optim"].get("checkpoint_every", eval_every) primary_metric = self.config["task"].get( "primary_metric", self.evaluator.task_primary_metric[self.name]) self.best_val_metric = 1e9 if "mae" in primary_metric else -1.0 self.metrics = {} # Calculate start_epoch from step instead of loading the epoch number # to prevent inconsistencies due to different batch size in checkpoint. start_epoch = self.step // len(self.train_loader) for epoch_int in range(start_epoch, self.config["optim"]["max_epochs"]): self.train_sampler.set_epoch(epoch_int) skip_steps = self.step % len(self.train_loader) train_loader_iter = iter(self.train_loader) for i in range(skip_steps, len(self.train_loader)): self.epoch = epoch_int + (i + 1) / len(self.train_loader) self.step = epoch_int * len(self.train_loader) + i + 1 self.model.train() # Get a batch. batch = next(train_loader_iter) if self.config["optim"]["optimizer"] == "LBFGS": def closure(): self.optimizer.zero_grad() with torch.cuda.amp.autocast( enabled=self.scaler is not None): out = self._forward(batch) loss = self._compute_loss(out, batch) loss.backward() return loss self.optimizer.step(closure) self.optimizer.zero_grad() with torch.cuda.amp.autocast( enabled=self.scaler is not None): out = self._forward(batch) loss = self._compute_loss(out, batch) else: # Forward, loss, backward. with torch.cuda.amp.autocast( enabled=self.scaler is not None): out = self._forward(batch) loss = self._compute_loss(out, batch) loss = self.scaler.scale(loss) if self.scaler else loss self._backward(loss) scale = self.scaler.get_scale() if self.scaler else 1.0 # Compute metrics. self.metrics = self._compute_metrics( out, batch, self.evaluator, self.metrics, ) self.metrics = self.evaluator.update("loss", loss.item() / scale, self.metrics) # Log metrics. log_dict = {k: self.metrics[k]["metric"] for k in self.metrics} log_dict.update({ "lr": self.scheduler.get_lr(), "epoch": self.epoch, "step": self.step, }) if (self.step % self.config["cmd"]["print_every"] == 0 and distutils.is_master() and not self.is_hpo): log_str = [ "{}: {:.2e}".format(k, v) for k, v in log_dict.items() ] logging.info(", ".join(log_str)) self.metrics = {} if self.logger is not None: self.logger.log( log_dict, step=self.step, split="train", ) if checkpoint_every != -1 and self.step % checkpoint_every == 0: self.save(checkpoint_file="checkpoint.pt", training_state=True) # Evaluate on val set every `eval_every` iterations. 
if self.step % eval_every == 0: if self.val_loader is not None: val_metrics = self.validate( split="val", disable_tqdm=disable_eval_tqdm, ) self.update_best( primary_metric, val_metrics, disable_eval_tqdm=disable_eval_tqdm, ) if self.is_hpo: self.hpo_update( self.epoch, self.step, self.metrics, val_metrics, ) if self.config["task"].get("eval_relaxations", False): if "relax_dataset" not in self.config["task"]: logging.warning( "Cannot evaluate relaxations, relax_dataset not specified" ) else: self.run_relaxations() if self.config["optim"].get("print_loss_and_lr", False): print( "epoch: " + str(self.epoch) + ", \tstep: " + str(self.step) + ", \tloss: " + str(loss.detach().item()) + ", \tlr: " + str(self.scheduler.get_lr()) + ", \tval: " + str(val_metrics["loss"]["total"]) ) if self.step % eval_every == 0 and self.val_loader is not None else print( "epoch: " + str(self.epoch) + ", \tstep: " + str(self.step) + ", \tloss: " + str(loss.detach().item()) + ", \tlr: " + str(self.scheduler.get_lr())) if self.scheduler.scheduler_type == "ReduceLROnPlateau": if (self.step % eval_every == 0 and self.config["optim"].get( "scheduler_loss", None) == "train"): self.scheduler.step(metrics=loss.detach().item(), ) elif self.step % eval_every == 0 and self.val_loader is not None: self.scheduler.step( metrics=val_metrics[primary_metric]["metric"], ) else: self.scheduler.step() break_below_lr = (self.config["optim"].get( "break_below_lr", None) is not None) and ( self.scheduler.get_lr() < self.config["optim"]["break_below_lr"]) if break_below_lr: break if break_below_lr: break torch.cuda.empty_cache() if checkpoint_every == -1: self.save(checkpoint_file="checkpoint.pt", training_state=True) self.train_dataset.close_db() if "val_dataset" in self.config: self.val_dataset.close_db() if "test_dataset" in self.config: self.test_dataset.close_db()
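# Standalone toy example of the LBFGS closure pattern used in the training loop
# above (a made-up quadratic objective, not project code). torch.optim.LBFGS may
# re-evaluate the closure several times per step, which is why the zero_grad /
# forward / backward sequence lives inside it.
import torch

x = torch.nn.Parameter(torch.tensor([5.0]))
optimizer = torch.optim.LBFGS([x], lr=0.1)

def closure():
    optimizer.zero_grad()
    loss = (x - 2.0).pow(2).sum()
    loss.backward()
    return loss

for _ in range(10):
    optimizer.step(closure)

print(x.item())  # converges towards 2.0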
def __init__( self, task, model, dataset, optimizer, identifier, run_dir=None, is_debug=False, is_vis=False, print_every=100, seed=None, logger="tensorboard", local_rank=0, amp=False, name="base_trainer", ): self.name = name if torch.cuda.is_available(): self.device = local_rank else: self.device = "cpu" if run_dir is None: run_dir = os.getcwd() run_dir = Path(run_dir) timestamp = torch.tensor(datetime.datetime.now().timestamp()).to( self.device) # create directories from master rank only distutils.broadcast(timestamp, 0) timestamp = datetime.datetime.fromtimestamp(timestamp).strftime( "%Y-%m-%d-%H-%M-%S") if identifier: timestamp += "-{}".format(identifier) self.config = { "task": task, "model": model.pop("name"), "model_attributes": model, "optim": optimizer, "logger": logger, "amp": amp, "cmd": { "identifier": identifier, "print_every": print_every, "seed": seed, "timestamp": timestamp, "checkpoint_dir": str(run_dir / "checkpoints" / timestamp), "results_dir": str(run_dir / "results" / timestamp), "logs_dir": str(run_dir / "logs" / logger / timestamp), }, } # AMP Scaler self.scaler = torch.cuda.amp.GradScaler() if amp else None if isinstance(dataset, list): self.config["dataset"] = dataset[0] if len(dataset) > 1: self.config["val_dataset"] = dataset[1] if len(dataset) > 2: self.config["test_dataset"] = dataset[2] else: self.config["dataset"] = dataset if not is_debug and distutils.is_master(): os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True) os.makedirs(self.config["cmd"]["results_dir"], exist_ok=True) os.makedirs(self.config["cmd"]["logs_dir"], exist_ok=True) self.is_debug = is_debug self.is_vis = is_vis if distutils.is_master(): print(yaml.dump(self.config, default_flow_style=False)) self.load() self.evaluator = Evaluator(task=name)
def train(self):
    self.best_val_mae = 1e9

    start_epoch = self.start_step // len(self.train_loader)
    for epoch in range(start_epoch, self.config["optim"]["max_epochs"]):
        self.train_sampler.set_epoch(epoch)
        self.model.train()

        # When resuming mid-epoch, skip the steps already covered by start_step.
        skip_steps = 0
        if epoch == start_epoch and start_epoch > 0:
            skip_steps = self.start_step % len(self.train_loader)
        train_loader_iter = iter(self.train_loader)

        for i in range(skip_steps, len(self.train_loader)):
            batch = next(train_loader_iter)

            # Forward, loss, backward.
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                out = self._forward(batch)
                loss = self._compute_loss(out, batch)
            loss = self.scaler.scale(loss) if self.scaler else loss
            self._backward(loss)
            scale = self.scaler.get_scale() if self.scaler else 1.0

            # Compute metrics.
            self.metrics = self._compute_metrics(
                out,
                batch,
                self.evaluator,
                metrics={},
            )
            self.metrics = self.evaluator.update(
                "loss", loss.item() / scale, self.metrics
            )

            # Print metrics, make plots.
            log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
            log_dict.update(
                {"epoch": epoch + (i + 1) / len(self.train_loader)}
            )
            if (
                i % self.config["cmd"]["print_every"] == 0
                and distutils.is_master()
            ):
                log_str = [
                    "{}: {:.4f}".format(k, v) for k, v in log_dict.items()
                ]
                print(", ".join(log_str))

            if self.logger is not None:
                self.logger.log(
                    log_dict,
                    step=epoch * len(self.train_loader) + i + 1,
                    split="train",
                )

            if self.update_lr_on_step:
                self.scheduler.step()

        if not self.update_lr_on_step:
            self.scheduler.step()
        torch.cuda.empty_cache()

        if self.val_loader is not None:
            val_metrics = self.validate(split="val", epoch=epoch)
            if (
                val_metrics[self.evaluator.task_primary_metric[self.name]][
                    "metric"
                ]
                < self.best_val_mae
            ):
                self.best_val_mae = val_metrics[
                    self.evaluator.task_primary_metric[self.name]
                ]["metric"]
                current_step = (epoch + 1) * len(self.train_loader)
                self.save(epoch + 1, current_step, val_metrics)
                if self.test_loader is not None:
                    self.predict(
                        self.test_loader,
                        results_file="predictions",
                        disable_tqdm=False,
                    )
        else:
            current_step = (epoch + 1) * len(self.train_loader)
            self.save(epoch + 1, current_step, self.metrics)

    self.train_dataset.close_db()
    if "val_dataset" in self.config:
        self.val_dataset.close_db()
    if "test_dataset" in self.config:
        self.test_dataset.close_db()
def train(self):
    eval_every = self.config["optim"].get("eval_every", -1)
    primary_metric = self.config["task"].get(
        "primary_metric", self.evaluator.task_primary_metric[self.name]
    )
    self.best_val_metric = 1e9 if "mae" in primary_metric else -1.0
    iters = 0
    self.metrics = {}

    for epoch in range(self.config["optim"]["max_epochs"]):
        self.model.train()
        for i, batch in enumerate(self.train_loader):
            # Forward, loss, backward.
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                out = self._forward(batch)
                loss = self._compute_loss(out, batch)
            loss = self.scaler.scale(loss) if self.scaler else loss
            self._backward(loss)
            scale = self.scaler.get_scale() if self.scaler else 1.0

            # Compute metrics.
            self.metrics = self._compute_metrics(
                out,
                batch,
                self.evaluator,
                self.metrics,
            )
            self.metrics = self.evaluator.update(
                "loss", loss.item() / scale, self.metrics
            )

            # Print metrics, make plots.
            log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
            log_dict.update(
                {"epoch": epoch + (i + 1) / len(self.train_loader)}
            )
            if (
                i % self.config["cmd"]["print_every"] == 0
                and distutils.is_master()
            ):
                log_str = [
                    "{}: {:.4f}".format(k, v) for k, v in log_dict.items()
                ]
                print(", ".join(log_str))
                self.metrics = {}

            if self.logger is not None:
                self.logger.log(
                    log_dict,
                    step=epoch * len(self.train_loader) + i + 1,
                    split="train",
                )

            iters += 1

            # Evaluate on val set every `eval_every` iterations.
            if eval_every != -1 and iters % eval_every == 0:
                if self.val_loader is not None:
                    val_metrics = self.validate(
                        split="val",
                        epoch=epoch - 1 + (i + 1) / len(self.train_loader),
                    )
                    # MAE-style metrics improve when they decrease,
                    # other metrics when they increase.
                    if (
                        "mae" in primary_metric
                        and val_metrics[primary_metric]["metric"]
                        < self.best_val_metric
                    ) or (
                        "mae" not in primary_metric
                        and val_metrics[primary_metric]["metric"]
                        > self.best_val_metric
                    ):
                        self.best_val_metric = val_metrics[primary_metric][
                            "metric"
                        ]
                        current_epoch = epoch + (i + 1) / len(
                            self.train_loader
                        )
                        self.save(current_epoch, val_metrics)
                        if self.test_loader is not None:
                            self.predict(
                                self.test_loader,
                                results_file="predictions",
                                disable_tqdm=False,
                            )

        self.scheduler.step()
        torch.cuda.empty_cache()

        if eval_every == -1:
            if self.val_loader is not None:
                val_metrics = self.validate(split="val", epoch=epoch)
                if (
                    "mae" in primary_metric
                    and val_metrics[primary_metric]["metric"]
                    < self.best_val_metric
                ) or (
                    "mae" not in primary_metric
                    and val_metrics[primary_metric]["metric"]
                    > self.best_val_metric
                ):
                    self.best_val_metric = val_metrics[primary_metric][
                        "metric"
                    ]
                    self.save(epoch + 1, val_metrics)
                    if self.test_loader is not None:
                        self.predict(
                            self.test_loader,
                            results_file="predictions",
                            disable_tqdm=False,
                        )
            else:
                self.save(epoch + 1, self.metrics)
def train(self):
    eval_every = self.config["optim"].get(
        "eval_every", len(self.train_loader)
    )
    primary_metric = self.config["task"].get(
        "primary_metric", self.evaluator.task_primary_metric[self.name]
    )
    self.best_val_metric = 1e9 if "mae" in primary_metric else -1.0
    iters = 0
    self.metrics = {}

    start_epoch = self.start_step // len(self.train_loader)
    for epoch in range(start_epoch, self.config["optim"]["max_epochs"]):
        self.train_sampler.set_epoch(epoch)

        # When resuming mid-epoch, skip the steps already covered by start_step.
        skip_steps = 0
        if epoch == start_epoch and start_epoch > 0:
            skip_steps = self.start_step % len(self.train_loader)
        train_loader_iter = iter(self.train_loader)

        for i in range(skip_steps, len(self.train_loader)):
            self.model.train()
            current_epoch = epoch + (i + 1) / len(self.train_loader)
            current_step = epoch * len(self.train_loader) + (i + 1)

            # Get a batch.
            batch = next(train_loader_iter)

            # Forward, loss, backward.
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                out = self._forward(batch)
                loss = self._compute_loss(out, batch)
            loss = self.scaler.scale(loss) if self.scaler else loss
            self._backward(loss)
            scale = self.scaler.get_scale() if self.scaler else 1.0

            # Compute metrics.
            self.metrics = self._compute_metrics(
                out,
                batch,
                self.evaluator,
                self.metrics,
            )
            self.metrics = self.evaluator.update(
                "loss", loss.item() / scale, self.metrics
            )

            # Log metrics.
            log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
            log_dict.update(
                {
                    "lr": self.scheduler.get_lr(),
                    "epoch": current_epoch,
                    "step": current_step,
                }
            )
            if (
                current_step % self.config["cmd"]["print_every"] == 0
                and distutils.is_master()
                and not self.is_hpo
            ):
                log_str = [
                    "{}: {:.2e}".format(k, v) for k, v in log_dict.items()
                ]
                print(", ".join(log_str))
                self.metrics = {}

            if self.logger is not None:
                self.logger.log(
                    log_dict,
                    step=current_step,
                    split="train",
                )

            iters += 1

            # Evaluate on val set every `eval_every` iterations.
            if iters % eval_every == 0:
                if self.val_loader is not None:
                    val_metrics = self.validate(
                        split="val",
                        epoch=epoch - 1 + (i + 1) / len(self.train_loader),
                    )
                    # MAE-style metrics improve when they decrease,
                    # other metrics when they increase.
                    if (
                        "mae" in primary_metric
                        and val_metrics[primary_metric]["metric"]
                        < self.best_val_metric
                    ) or (
                        "mae" not in primary_metric
                        and val_metrics[primary_metric]["metric"]
                        > self.best_val_metric
                    ):
                        self.best_val_metric = val_metrics[primary_metric][
                            "metric"
                        ]
                        self.save(current_epoch, current_step, val_metrics)
                        if self.test_loader is not None:
                            self.predict(
                                self.test_loader,
                                results_file="predictions",
                                disable_tqdm=False,
                            )
                    if self.is_hpo:
                        self.hpo_update(
                            current_epoch,
                            current_step,
                            self.metrics,
                            val_metrics,
                        )
                else:
                    self.save(current_epoch, current_step, self.metrics)

            if self.scheduler.scheduler_type == "ReduceLROnPlateau":
                if iters % eval_every == 0 and self.val_loader is not None:
                    self.scheduler.step(
                        metrics=val_metrics[primary_metric]["metric"],
                    )
            else:
                self.scheduler.step()

        torch.cuda.empty_cache()

    self.train_dataset.close_db()
    if "val_dataset" in self.config:
        self.val_dataset.close_db()
    if "test_dataset" in self.config:
        self.test_dataset.close_db()
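# Hypothetical helper (not project API) expressing the best-checkpoint condition
# used in the training loops above: metrics whose name contains "mae" are errors
# to minimize, everything else is a score to maximize.
def is_improvement(metric_name, value, best):
    if "mae" in metric_name:
        return value < best
    return value > best

assert is_improvement("energy_mae", 0.41, 0.50)        # lower MAE is better
assert not is_improvement("energy_mae", 0.70, 0.50)
assert is_improvement("custom_score", 0.62, 0.55)      # higher score is better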
def load_task(self): print("### Loading dataset: {}".format(self.config["task"]["dataset"])) self.parallel_collater = ParallelCollater( 1 if not self.cpu else 0, self.config["model_attributes"].get("otf_graph", False), ) if self.config["task"]["dataset"] == "trajectory_lmdb": self.train_dataset = registry.get_dataset_class( self.config["task"]["dataset"])(self.config["dataset"]) self.train_loader = DataLoader( self.train_dataset, batch_size=self.config["optim"]["batch_size"], shuffle=True, collate_fn=self.parallel_collater, num_workers=self.config["optim"]["num_workers"], pin_memory=True, ) self.val_loader = self.test_loader = None if "val_dataset" in self.config: self.val_dataset = registry.get_dataset_class( self.config["task"]["dataset"])(self.config["val_dataset"]) self.val_loader = DataLoader( self.val_dataset, self.config["optim"].get("eval_batch_size", 64), shuffle=False, collate_fn=self.parallel_collater, num_workers=self.config["optim"]["num_workers"], pin_memory=True, ) if "test_dataset" in self.config: self.test_dataset = registry.get_dataset_class( self.config["task"]["dataset"])( self.config["test_dataset"]) self.test_loader = DataLoader( self.test_dataset, self.config["optim"].get("eval_batch_size", 64), shuffle=False, collate_fn=self.parallel_collater, num_workers=self.config["optim"]["num_workers"], pin_memory=True, ) if "relax_dataset" in self.config["task"]: assert os.path.isfile( self.config["task"]["relax_dataset"]["src"]) self.relax_dataset = registry.get_dataset_class( "single_point_lmdb")(self.config["task"]["relax_dataset"]) self.relax_sampler = DistributedSampler( self.relax_dataset, num_replicas=distutils.get_world_size(), rank=distutils.get_rank(), shuffle=False, ) self.relax_loader = DataLoader( self.relax_dataset, batch_size=self.config["optim"].get("eval_batch_size", 64), collate_fn=self.parallel_collater, num_workers=self.config["optim"]["num_workers"], pin_memory=True, sampler=self.relax_sampler, ) else: self.dataset = registry.get_dataset_class( self.config["task"]["dataset"])(self.config["dataset"]) ( self.train_loader, self.val_loader, self.test_loader, ) = self.dataset.get_dataloaders( batch_size=self.config["optim"]["batch_size"], collate_fn=self.parallel_collater, ) self.num_targets = 1 # Normalizer for the dataset. # Compute mean, std of training set labels. 
self.normalizers = {} if self.config["dataset"].get("normalize_labels", False): if "target_mean" in self.config["dataset"]: self.normalizers["target"] = Normalizer( mean=self.config["dataset"]["target_mean"], std=self.config["dataset"]["target_std"], device=self.device, ) else: self.normalizers["target"] = Normalizer( tensor=self.train_loader.dataset.data.y[ self.train_loader.dataset.__indices__], device=self.device, ) # If we're computing gradients wrt input, set mean of normalizer to 0 -- # since it is lost when compute dy / dx -- and std to forward target std if self.config["model_attributes"].get("regress_forces", True): if self.config["dataset"].get("normalize_labels", False): if "grad_target_mean" in self.config["dataset"]: self.normalizers["grad_target"] = Normalizer( mean=self.config["dataset"]["grad_target_mean"], std=self.config["dataset"]["grad_target_std"], device=self.device, ) else: self.normalizers["grad_target"] = Normalizer( tensor=self.train_loader.dataset.data.y[ self.train_loader.dataset.__indices__], device=self.device, ) self.normalizers["grad_target"].mean.fill_(0) if (self.is_vis and self.config["task"]["dataset"] != "qm9" and distutils.is_master()): # Plot label distribution. plots = [ plot_histogram( self.train_loader.dataset.data.y.tolist(), xlabel="{}/raw".format(self.config["task"]["labels"][0]), ylabel="# Examples", title="Split: train", ), plot_histogram( self.val_loader.dataset.data.y.tolist(), xlabel="{}/raw".format(self.config["task"]["labels"][0]), ylabel="# Examples", title="Split: val", ), plot_histogram( self.test_loader.dataset.data.y.tolist(), xlabel="{}/raw".format(self.config["task"]["labels"][0]), ylabel="# Examples", title="Split: test", ), ] self.logger.log_plots(plots)
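# Zeroing the mean of the grad_target normalizer above follows from the chain
# rule: forces are gradients of the normalized energy, so the additive mean
# drops out and only the standard deviation survives when denormalizing them.
# A toy sketch with a hypothetical SimpleNormalizer (not the project's
# Normalizer class):
import torch

class SimpleNormalizer:
    # Illustrative stand-in: (x - mean) / std and its inverse.
    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def norm(self, x):
        return (x - self.mean) / self.std

    def denorm(self, x):
        return x * self.std + self.mean

# d/dx [(E(x) - mean) / std] = (dE/dx) / std, so force denormalization only
# needs the std -- which is exactly what setting the mean to 0 achieves.
grad_norm = SimpleNormalizer(mean=0.0, std=2.0)
normalized_force = torch.tensor([0.3])
print(grad_norm.denorm(normalized_force))  # tensor([0.6000]) -- scaled by std only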
def save(self, epoch, step, metrics):
    if not self.is_debug and distutils.is_master() and not self.is_hpo:
        save_checkpoint(
            self.save_state(epoch, step, metrics),
            self.config["cmd"]["checkpoint_dir"],
        )
def train(self, disable_eval_tqdm=False): eval_every = self.config["optim"].get("eval_every", len(self.train_loader)) primary_metric = self.config["task"].get( "primary_metric", self.evaluator.task_primary_metric[self.name]) self.best_val_mae = 1e9 # Calculate start_epoch from step instead of loading the epoch number # to prevent inconsistencies due to different batch size in checkpoint. start_epoch = self.step // len(self.train_loader) for epoch_int in range(start_epoch, self.config["optim"]["max_epochs"]): self.train_sampler.set_epoch(epoch_int) skip_steps = self.step % len(self.train_loader) train_loader_iter = iter(self.train_loader) for i in range(skip_steps, len(self.train_loader)): self.epoch = epoch_int + (i + 1) / len(self.train_loader) self.step = epoch_int * len(self.train_loader) + i + 1 self.model.train() # Get a batch. batch = next(train_loader_iter) # Forward, loss, backward. with torch.cuda.amp.autocast(enabled=self.scaler is not None): out = self._forward(batch) loss = self._compute_loss(out, batch) loss = self.scaler.scale(loss) if self.scaler else loss self._backward(loss) scale = self.scaler.get_scale() if self.scaler else 1.0 # Compute metrics. self.metrics = self._compute_metrics( out, batch, self.evaluator, metrics={}, ) self.metrics = self.evaluator.update("loss", loss.item() / scale, self.metrics) # Log metrics. log_dict = {k: self.metrics[k]["metric"] for k in self.metrics} log_dict.update({ "lr": self.scheduler.get_lr(), "epoch": self.epoch, "step": self.step, }) if (self.step % self.config["cmd"]["print_every"] == 0 and distutils.is_master() and not self.is_hpo): log_str = [ "{}: {:.2e}".format(k, v) for k, v in log_dict.items() ] print(", ".join(log_str)) self.metrics = {} if self.logger is not None: self.logger.log( log_dict, step=self.step, split="train", ) # Evaluate on val set after every `eval_every` iterations. if self.step % eval_every == 0: self.save(checkpoint_file="checkpoint.pt", training_state=True) if self.val_loader is not None: val_metrics = self.validate( split="val", disable_tqdm=disable_eval_tqdm, ) if (val_metrics[self.evaluator.task_primary_metric[ self.name]]["metric"] < self.best_val_mae): self.best_val_mae = val_metrics[ self.evaluator.task_primary_metric[ self.name]]["metric"] self.save( metrics=val_metrics, checkpoint_file="best_checkpoint.pt", training_state=False, ) if self.test_loader is not None: self.predict( self.test_loader, results_file="predictions", disable_tqdm=False, ) if self.is_hpo: self.hpo_update( self.epoch, self.step, self.metrics, val_metrics, ) else: self.save(self.epoch, self.step, self.metrics) if self.scheduler.scheduler_type == "ReduceLROnPlateau": if self.step % eval_every == 0: self.scheduler.step( metrics=val_metrics[primary_metric]["metric"], ) else: self.scheduler.step() torch.cuda.empty_cache() self.train_dataset.close_db() if "val_dataset" in self.config: self.val_dataset.close_db() if "test_dataset" in self.config: self.test_dataset.close_db()
def __init__( self, task, model, dataset, optimizer, identifier, run_dir=None, is_debug=False, is_vis=False, print_every=100, seed=None, logger="tensorboard", local_rank=0, amp=False, ): if run_dir is None: run_dir = os.getcwd() timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") if identifier: timestamp += "-{}".format(identifier) self.config = { "task": task, "model": model.pop("name"), "model_attributes": model, "optim": optimizer, "logger": logger, "amp": amp, "cmd": { "identifier": identifier, "print_every": print_every, "seed": seed, "timestamp": timestamp, "checkpoint_dir": os.path.join(run_dir, "checkpoints", timestamp), "results_dir": os.path.join(run_dir, "results", timestamp), "logs_dir": os.path.join(run_dir, "logs", logger, timestamp), }, } # AMP Scaler self.scaler = torch.cuda.amp.GradScaler() if amp else None if isinstance(dataset, list): self.config["dataset"] = dataset[0] if len(dataset) > 1: self.config["val_dataset"] = dataset[1] else: self.config["dataset"] = dataset if not is_debug and distutils.is_master(): os.makedirs(self.config["cmd"]["checkpoint_dir"]) os.makedirs(self.config["cmd"]["results_dir"]) os.makedirs(self.config["cmd"]["logs_dir"]) self.is_debug = is_debug self.is_vis = is_vis if torch.cuda.is_available(): self.device = local_rank else: self.device = "cpu" if distutils.is_master(): print(yaml.dump(self.config, default_flow_style=False)) self.load() self.evaluator = Evaluator(task="s2ef")