def load_model(self):
    super(ForcesTrainer, self).load_model()

    self.model = OCPDataParallel(
        self.model,
        output_device=self.output_device,
        num_gpus=self.config["optim"].get("num_gpus", 1),
    )
    self.model.to(self.device)
def main(args):
    with open(args.config_yml) as f:
        yml_file = yaml.safe_load(f)

    model_name = yml_file["model"].pop("name", None)
    print(f"#### Loading model: {model_name}")

    checkpoint_path = yml_file["checkpoint"]["src"]
    checkpoint = modify_checkpoint(checkpoint_path)

    model = DimeNetPlusPlusWrap(**yml_file["model"])
    model.load_state_dict(checkpoint)
    model = OCPDataParallel(model, output_device=0, num_gpus=1)

    if yml_file["dataset"]["src"]:
        batch = lmdb_to_batch(yml_file["dataset"]["src"])
    else:
        atoms = ase.io.read(
            "../tests/models/atoms.json",
            index=0,
            format="json",
        )
        a2g = AtomsToGraphs(
            max_neigh=12,
            radius=6,
            dummy_distance=7,
            dummy_index=-1,
            r_energy=True,
            r_forces=True,
            r_distances=True,
        )
        batch = Batch.from_data_list(a2g.convert_all([atoms]))

    output = model(batch)

    viz = model_viz(checkpoint_path)

    if yml_file["task"]["computation_graph"]:
        print("#### Plotting computation graph")
        viz.computation_graph(model, batch)

    if yml_file["task"]["t-sne_viz"]:
        print("#### Plotting t-sne")
        emb_weight = checkpoint["emb.emb.weight"].cpu().numpy()
        viz.tsne_viz_emb(emb_weight)

    if yml_file["task"]["pca_t-sne_viz"]:
        print("#### Plotting PCA reduced t-sne")
        emb_weight = checkpoint["emb.emb.weight"].cpu().numpy()
        res = viz.pca(emb_weight, n=50)
        viz.tsne_viz_emb(res)

    if yml_file["task"]["raw_weights"]:
        print("#### Plotting raw emb weights")
        emb_weight = checkpoint["emb.emb.weight"].cpu().numpy()
        viz.raw_weights_viz(emb_weight)

    if yml_file["task"]["is2rs_plot"]:
        print("#### Plotting is2rs comparison")
        viz.create_is2rs_plots(batch, output)
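For reference, a minimal sketch of how the script above could be invoked and what its config is assumed to look like. The argument name and the YAML keys below are inferred only from what main() reads; they are illustrative, not the project's canonical config format.

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Visualize a trained DimeNet++ checkpoint (illustrative entry point)"
    )
    parser.add_argument(
        "--config-yml",
        dest="config_yml",
        required=True,
        help="YAML file with model / checkpoint / dataset / task sections",
    )
    main(parser.parse_args())

# Hypothetical config, mirroring only the keys main() accesses:
#
# model:
#   name: dimenetplusplus
#   hidden_channels: 256        # remaining keys are passed to DimeNetPlusPlusWrap(**model)
# checkpoint:
#   src: checkpoints/example/checkpoint.pt
# dataset:
#   src: data/val.lmdb          # empty string falls back to ../tests/models/atoms.json
# task:
#   computation_graph: True
#   t-sne_viz: True
#   pca_t-sne_viz: False
#   raw_weights: False
#   is2rs_plot: False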
def load_model(self):
    super(DistributedForcesTrainer, self).load_model()

    self.model = OCPDataParallel(
        self.model,
        output_device=self.device,
        num_gpus=1,
    )
    self.model = DistributedDataParallel(
        self.model,
        device_ids=[self.device],
        find_unused_parameters=True,
    )
def load_model(self):
    super(EnergyTrainer, self).load_model()

    self.model = OCPDataParallel(
        self.model,
        output_device=self.device,
        num_gpus=self.config["optim"].get("num_gpus", 1),
    )
    if distutils.initialized():
        self.model = DistributedDataParallel(
            self.model, device_ids=[self.device]
        )
def load_model(self):
    super(DistributedEnergyTrainer, self).load_model()

    self.model = OCPDataParallel(
        self.model,
        output_device=self.device,
        num_gpus=self.config["optim"].get("num_gpus", 1),
    )
    self.model = DistributedDataParallel(
        self.model,
        device_ids=[self.device],
        find_unused_parameters=True,
    )
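The Distributed* variants above assume a torch.distributed process group has already been initialized in each worker before load_model() runs; that setup lives elsewhere (presumably in the project's distutils helpers). A generic sketch of what such setup typically looks like, not the actual ocpmodels code:

import os

import torch
import torch.distributed as dist


def setup_distributed(local_rank):
    # Illustrative only: one process per GPU, rendezvous via environment variables.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(
        backend="nccl" if torch.cuda.is_available() else "gloo",
        rank=int(os.environ.get("RANK", 0)),
        world_size=int(os.environ.get("WORLD_SIZE", 1)),
    )
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)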
def load_model(self):
    # Build model
    if distutils.is_master():
        logging.info(f"Loading model: {self.config['model']}")

    # TODO(abhshkdz): Eventually move towards computing features on-the-fly
    # and remove dependence from `.edge_attr`.
    bond_feat_dim = None
    if self.config["task"]["dataset"] in [
        "trajectory_lmdb",
        "single_point_lmdb",
    ]:
        bond_feat_dim = self.config["model_attributes"].get(
            "num_gaussians", 50
        )
    else:
        raise NotImplementedError

    loader = self.train_loader or self.val_loader or self.test_loader
    self.model = registry.get_model_class(self.config["model"])(
        loader.dataset[0].x.shape[-1]
        if loader
        and hasattr(loader.dataset[0], "x")
        and loader.dataset[0].x is not None
        else None,
        bond_feat_dim,
        self.num_targets,
        **self.config["model_attributes"],
    ).to(self.device)

    if distutils.is_master():
        logging.info(
            f"Loaded {self.model.__class__.__name__} with "
            f"{self.model.num_params} parameters."
        )

    if self.logger is not None:
        self.logger.watch(self.model)

    self.model = OCPDataParallel(
        self.model,
        output_device=self.device,
        num_gpus=1 if not self.cpu else 0,
    )
    if distutils.initialized():
        self.model = DistributedDataParallel(
            self.model, device_ids=[self.device]
        )
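load_model() above resolves the model class by the string name in config["model"] through a registry. A minimal sketch of that lookup pattern, not the actual ocpmodels.common.registry implementation:

class Registry:
    # Name -> class mapping, filled in by decorators at import time.
    model_mapping = {}

    @classmethod
    def register_model(cls, name):
        def wrap(model_cls):
            cls.model_mapping[name] = model_cls
            return model_cls

        return wrap

    @classmethod
    def get_model_class(cls, name):
        return cls.model_mapping.get(name, None)


# Usage sketch: a model registers itself under the string used in the config.
@Registry.register_model("toy_model")
class ToyModel:
    def __init__(self, num_atoms, bond_feat_dim, num_targets, **kwargs):
        self.num_targets = num_targets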
class ForcesTrainer(BaseTrainer):
    def __init__(
        self,
        task,
        model,
        dataset,
        optimizer,
        identifier,
        run_dir=None,
        is_debug=False,
        is_vis=False,
        print_every=100,
        seed=None,
        logger="tensorboard",
    ):
        if run_dir is None:
            run_dir = os.getcwd()
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        if identifier:
            timestamp += "-{}".format(identifier)

        self.config = {
            "task": task,
            "model": model.pop("name"),
            "model_attributes": model,
            "optim": optimizer,
            "logger": logger,
            "cmd": {
                "identifier": identifier,
                "print_every": print_every,
                "seed": seed,
                "timestamp": timestamp,
                "checkpoint_dir": os.path.join(run_dir, "checkpoints", timestamp),
                "results_dir": os.path.join(run_dir, "results", timestamp),
                "logs_dir": os.path.join(run_dir, "logs", logger, timestamp),
            },
        }

        if isinstance(dataset, list):
            self.config["dataset"] = dataset[0]
            if len(dataset) > 1:
                self.config["val_dataset"] = dataset[1]
        else:
            self.config["dataset"] = dataset

        if not is_debug:
            os.makedirs(self.config["cmd"]["checkpoint_dir"])
            os.makedirs(self.config["cmd"]["results_dir"])
            os.makedirs(self.config["cmd"]["logs_dir"])

        self.is_debug = is_debug
        self.is_vis = is_vis

        if torch.cuda.is_available():
            device_ids = list(range(torch.cuda.device_count()))
            self.output_device = self.config["optim"].get(
                "output_device", device_ids[0]
            )
            self.device = f"cuda:{self.output_device}"
        else:
            self.device = "cpu"

        print(yaml.dump(self.config, default_flow_style=False))
        self.load()

    def load_task(self):
        print("### Loading dataset: {}".format(self.config["task"]["dataset"]))

        self.parallel_collater = ParallelCollater(
            self.config["optim"].get("num_gpus", 1)
        )
        if self.config["task"]["dataset"] == "trajectory_lmdb":
            self.train_dataset = registry.get_dataset_class(
                self.config["task"]["dataset"]
            )(self.config["dataset"])

            self.train_loader = DataLoader(
                self.train_dataset,
                batch_size=self.config["optim"]["batch_size"],
                shuffle=True,
                collate_fn=self.parallel_collater,
                num_workers=self.config["optim"]["num_workers"],
            )

            self.val_loader = self.test_loader = None

            if "val_dataset" in self.config:
                self.val_dataset = registry.get_dataset_class(
                    self.config["task"]["dataset"]
                )(self.config["val_dataset"])
                self.val_loader = DataLoader(
                    self.val_dataset,
                    self.config["optim"].get("eval_batch_size", 64),
                    shuffle=False,
                    collate_fn=self.parallel_collater,
                    num_workers=self.config["optim"]["num_workers"],
                )
        else:
            self.dataset = registry.get_dataset_class(
                self.config["task"]["dataset"]
            )(self.config["dataset"])
            (
                self.train_loader,
                self.val_loader,
                self.test_loader,
            ) = self.dataset.get_dataloaders(
                batch_size=self.config["optim"]["batch_size"],
                collate_fn=self.parallel_collater,
            )

        self.num_targets = 1

        # Normalizer for the dataset.
        # Compute mean, std of training set labels.
        self.normalizers = {}
        if self.config["dataset"].get("normalize_labels", True):
            if "target_mean" in self.config["dataset"]:
                self.normalizers["target"] = Normalizer(
                    mean=self.config["dataset"]["target_mean"],
                    std=self.config["dataset"]["target_std"],
                    device=self.device,
                )
            else:
                self.normalizers["target"] = Normalizer(
                    tensor=self.train_loader.dataset.data.y[
                        self.train_loader.dataset.__indices__
                    ],
                    device=self.device,
                )

        # If we're computing gradients wrt input, set mean of normalizer to 0 --
        # since it is lost when computing dy / dx -- and std to forward target std.
        if "grad_input" in self.config["task"]:
            if self.config["dataset"].get("normalize_labels", True):
                if "grad_target_mean" in self.config["dataset"]:
                    self.normalizers["grad_target"] = Normalizer(
                        mean=self.config["dataset"]["grad_target_mean"],
                        std=self.config["dataset"]["grad_target_std"],
                        device=self.device,
                    )
                else:
                    self.normalizers["grad_target"] = Normalizer(
                        tensor=self.train_loader.dataset.data.y[
                            self.train_loader.dataset.__indices__
                        ],
                        device=self.device,
                    )
                    self.normalizers["grad_target"].mean.fill_(0)

        if self.is_vis and self.config["task"]["dataset"] != "qm9":
            # Plot label distribution.
            plots = [
                plot_histogram(
                    self.train_loader.dataset.data.y.tolist(),
                    xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                    ylabel="# Examples",
                    title="Split: train",
                ),
                plot_histogram(
                    self.val_loader.dataset.data.y.tolist(),
                    xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                    ylabel="# Examples",
                    title="Split: val",
                ),
                plot_histogram(
                    self.test_loader.dataset.data.y.tolist(),
                    xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                    ylabel="# Examples",
                    title="Split: test",
                ),
            ]
            self.logger.log_plots(plots)

    def load_model(self):
        super(ForcesTrainer, self).load_model()

        self.model = OCPDataParallel(
            self.model,
            output_device=self.output_device,
            num_gpus=self.config["optim"].get("num_gpus", 1),
        )
        self.model.to(self.device)

    # Takes in a new data source and generates predictions on it.
    def predict(self, dataset, batch_size=32):
        if isinstance(dataset, dict):
            if self.config["task"]["dataset"] == "trajectory_lmdb":
                print("### Generating predictions on {}.".format(dataset["src"]))
            else:
                print(
                    "### Generating predictions on {}.".format(
                        dataset["src"] + dataset["traj"]
                    )
                )

            dataset = registry.get_dataset_class(
                self.config["task"]["dataset"]
            )(dataset)
            data_loader = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=False,
                collate_fn=self.parallel_collater,
            )
        elif isinstance(dataset, torch_geometric.data.Batch):
            data_loader = [[dataset]]
        else:
            raise NotImplementedError

        self.model.eval()
        predictions = {"energy": [], "forces": []}

        for i, batch_list in enumerate(data_loader):
            out, _ = self._forward(batch_list, compute_metrics=False)
            if self.normalizers is not None and "target" in self.normalizers:
                out["output"] = self.normalizers["target"].denorm(out["output"])
                out["force_output"] = self.normalizers["grad_target"].denorm(
                    out["force_output"]
                )
            atoms_sum = 0
            predictions["energy"].extend(out["output"].tolist())
            batch_natoms = torch.cat([batch.natoms for batch in batch_list])
            for natoms in batch_natoms:
                predictions["forces"].append(
                    out["force_output"][atoms_sum : natoms + atoms_sum]
                    .cpu()
                    .detach()
                    .numpy()
                )
                atoms_sum += natoms

        return predictions

    def train(self):
        for epoch in range(self.config["optim"]["max_epochs"]):
            self.model.train()

            for i, batch in enumerate(self.train_loader):
                # Forward, loss, backward.
                out, metrics = self._forward(batch)
                loss = self._compute_loss(out, batch)
                self._backward(loss)

                # Update meter.
                meter_update_dict = {
                    "epoch": epoch + (i + 1) / len(self.train_loader),
                    "loss": loss.item(),
                }
                meter_update_dict.update(metrics)
                self.meter.update(meter_update_dict)

                # Make plots.
                if self.logger is not None:
                    self.logger.log(
                        meter_update_dict,
                        step=epoch * len(self.train_loader) + i + 1,
                        split="train",
                    )

                # Print metrics.
                if i % self.config["cmd"]["print_every"] == 0:
                    print(self.meter)

            self.scheduler.step()
            torch.cuda.empty_cache()

            if self.val_loader is not None:
                self.validate(split="val", epoch=epoch)

            if self.test_loader is not None:
                self.validate(split="test", epoch=epoch)

            if (
                "relaxation_dir" in self.config["task"]
                and self.config["task"].get("ml_relax", "end") == "train"
            ):
                self.validate_relaxation(split="val", epoch=epoch)

            if not self.is_debug:
                save_checkpoint(
                    {
                        "epoch": epoch + 1,
                        "state_dict": self.model.state_dict(),
                        "optimizer": self.optimizer.state_dict(),
                        "normalizers": {
                            key: value.state_dict()
                            for key, value in self.normalizers.items()
                        },
                        "config": self.config,
                    },
                    self.config["cmd"]["checkpoint_dir"],
                )

        if (
            "relaxation_dir" in self.config["task"]
            and self.config["task"].get("ml_relax", "end") == "end"
        ):
            self.validate_relaxation(split="val", epoch=epoch)

    def validate(self, split="val", epoch=None):
        print("### Evaluating on {}.".format(split))
        self.model.eval()

        meter = Meter(split=split)
        loader = self.val_loader if split == "val" else self.test_loader

        for i, batch in enumerate(loader):
            # Forward.
            out, metrics = self._forward(batch)
            loss = self._compute_loss(out, batch)

            # Update meter.
            meter_update_dict = {"loss": loss.item()}
            meter_update_dict.update(metrics)
            meter.update(meter_update_dict)

        # Make plots.
        if self.logger is not None and epoch is not None:
            log_dict = meter.get_scalar_dict()
            log_dict.update({"epoch": epoch + 1})
            self.logger.log(
                log_dict,
                step=(epoch + 1) * len(self.train_loader),
                split=split,
            )

        print(meter)

    def validate_relaxation(self, split="val", epoch=None):
        print("### Evaluating ML-relaxation")
        self.model.eval()

        metrics = {}
        meter = Meter(split=split)

        mae_energy, mae_structure = relax_eval(
            trainer=self,
            traj_dir=self.config["task"]["relaxation_dir"],
            metric=self.config["task"]["metric"],
            steps=self.config["task"].get("relaxation_steps", 300),
            fmax=self.config["task"].get("relaxation_fmax", 0.01),
            results_dir=self.config["cmd"]["results_dir"],
        )

        metrics[
            "relaxed_energy/{}".format(self.config["task"]["metric"])
        ] = mae_energy
        metrics[
            "relaxed_structure/{}".format(self.config["task"]["metric"])
        ] = mae_structure
        meter.update(metrics)

        # Make plots.
        if self.logger is not None and epoch is not None:
            log_dict = meter.get_scalar_dict()
            log_dict.update({"epoch": epoch + 1})
            self.logger.log(
                log_dict,
                step=(epoch + 1) * len(self.train_loader),
                split=split,
            )

        print(meter)
        return mae_energy, mae_structure

    def _forward(self, batch_list, compute_metrics=True):
        out = {}

        # forward pass.
        if self.config["model_attributes"].get("regress_forces", True):
            output, output_forces = self.model(batch_list)
        else:
            output = self.model(batch_list)

        if output.shape[-1] == 1:
            output = output.view(-1)
        out["output"] = output

        force_output = None
        if self.config["model_attributes"].get("regress_forces", True):
            out["force_output"] = output_forces
            force_output = output_forces

        if not compute_metrics:
            return out, None

        metrics = {}
        energy_target = torch.cat(
            [batch.y.to(self.device) for batch in batch_list], dim=0
        )

        if self.config["dataset"].get("normalize_labels", True):
            errors = eval(self.config["task"]["metric"])(
                self.normalizers["target"].denorm(output).cpu(),
                energy_target.cpu(),
            ).view(-1)
        else:
            errors = eval(self.config["task"]["metric"])(
                output.cpu(), energy_target.cpu()
            ).view(-1)

        if (
            "label_index" in self.config["task"]
            and self.config["task"]["label_index"] is not False
        ):
            # TODO(abhshkdz): Get rid of this edge case for QM9.
            # This is only because QM9 has multiple targets and we can either
            # jointly predict all of them or one particular target.
            metrics[
                "{}/{}".format(
                    self.config["task"]["labels"][
                        self.config["task"]["label_index"]
                    ],
                    self.config["task"]["metric"],
                )
            ] = errors[0]
        else:
            for i, label in enumerate(self.config["task"]["labels"]):
                metrics[
                    "{}/{}".format(label, self.config["task"]["metric"])
                ] = errors[i]

        if "grad_input" in self.config["task"]:
            force_pred = force_output
            force_target = torch.cat(
                [batch.force.to(self.device) for batch in batch_list], dim=0
            )

            if self.config["task"].get("eval_on_free_atoms", True):
                fixed = torch.cat([batch.fixed for batch in batch_list])
                mask = fixed == 0
                force_pred = force_pred[mask]
                force_target = force_target[mask]

            if self.config["dataset"].get("normalize_labels", True):
                grad_input_errors = eval(self.config["task"]["metric"])(
                    self.normalizers["grad_target"].denorm(force_pred).cpu(),
                    force_target.cpu(),
                )
            else:
                grad_input_errors = eval(self.config["task"]["metric"])(
                    force_pred.cpu(), force_target.cpu()
                )

            metrics[
                "force_x/{}".format(self.config["task"]["metric"])
            ] = grad_input_errors[0]
            metrics[
                "force_y/{}".format(self.config["task"]["metric"])
            ] = grad_input_errors[1]
            metrics[
                "force_z/{}".format(self.config["task"]["metric"])
            ] = grad_input_errors[2]

        return out, metrics

    def _compute_loss(self, out, batch_list):
        loss = []

        energy_target = torch.cat(
            [batch.y.to(self.device) for batch in batch_list], dim=0
        )
        force_target = torch.cat(
            [batch.force.to(self.device) for batch in batch_list], dim=0
        )

        if self.config["dataset"].get("normalize_labels", True):
            target_normed = self.normalizers["target"].norm(energy_target)
        else:
            target_normed = energy_target

        loss.append(self.criterion(out["output"], target_normed))

        # TODO(abhshkdz): Test support for gradients wrt input.
        # TODO(abhshkdz): Make this general; remove dependence on `.forces`.
        if "grad_input" in self.config["task"]:
            if self.config["dataset"].get("normalize_labels", True):
                grad_target_normed = self.normalizers["grad_target"].norm(
                    force_target
                )
            else:
                grad_target_normed = force_target

            # Force coefficient = 30 has been working well for us.
            force_mult = self.config["optim"].get("force_coefficient", 30)
            if self.config["task"].get("train_on_free_atoms", False):
                fixed = torch.cat([batch.fixed for batch in batch_list])
                mask = fixed == 0
                loss.append(
                    force_mult
                    * self.criterion(
                        out["force_output"][mask], grad_target_normed[mask]
                    )
                )
            else:
                loss.append(
                    force_mult
                    * self.criterion(out["force_output"], grad_target_normed)
                )

        # Sanity check to make sure the compute graph is correct.
        for lc in loss:
            assert hasattr(lc, "grad_fn")

        loss = sum(loss)
        return loss
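A standalone illustration of why _compute_loss() asserts that every loss term has a grad_fn: a term built from detached or no-grad tensors would silently drop out of backpropagation, and the assert catches that before _backward() runs. This snippet is illustrative only and not part of the trainer.

import torch
import torch.nn.functional as F

pred = torch.randn(4, requires_grad=True)
target = torch.randn(4)

good_term = F.mse_loss(pred, target)            # part of the graph, has grad_fn
bad_term = F.mse_loss(pred.detach(), target)    # detached, grad_fn is None

assert good_term.grad_fn is not None
assert bad_term.grad_fn is None                 # would trip the trainer's sanity check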
class EnergyTrainer(BaseTrainer):
    def __init__(
        self,
        task,
        model,
        dataset,
        optimizer,
        identifier,
        run_dir=None,
        is_debug=False,
        is_vis=False,
        print_every=100,
        seed=None,
        logger="tensorboard",
        local_rank=0,
        amp=False,
    ):
        if run_dir is None:
            run_dir = os.getcwd()
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        if identifier:
            timestamp += "-{}".format(identifier)

        self.config = {
            "task": task,
            "model": model.pop("name"),
            "model_attributes": model,
            "optim": optimizer,
            "logger": logger,
            "cmd": {
                "identifier": identifier,
                "print_every": print_every,
                "seed": seed,
                "timestamp": timestamp,
                "checkpoint_dir": os.path.join(run_dir, "checkpoints", timestamp),
                "results_dir": os.path.join(run_dir, "results", timestamp),
                "logs_dir": os.path.join(run_dir, "logs", logger, timestamp),
            },
            "amp": amp,
        }

        # AMP Scaler
        self.scaler = torch.cuda.amp.GradScaler() if amp else None

        if isinstance(dataset, list):
            self.config["dataset"] = dataset[0]
            if len(dataset) > 1:
                self.config["val_dataset"] = dataset[1]
        else:
            self.config["dataset"] = dataset

        if not is_debug:
            os.makedirs(self.config["cmd"]["checkpoint_dir"])
            os.makedirs(self.config["cmd"]["results_dir"])
            os.makedirs(self.config["cmd"]["logs_dir"])

        self.is_debug = is_debug
        self.is_vis = is_vis

        if torch.cuda.is_available():
            device_ids = list(range(torch.cuda.device_count()))
            self.output_device = self.config["optim"].get(
                "output_device", device_ids[0]
            )
            self.device = f"cuda:{self.output_device}"
        else:
            self.device = "cpu"

        print(yaml.dump(self.config, default_flow_style=False))
        self.load()

        self.evaluator = Evaluator(task="is2re")

    def load_task(self):
        print("### Loading dataset: {}".format(self.config["task"]["dataset"]))

        self.parallel_collater = ParallelCollater(
            self.config["optim"].get("num_gpus", 1)
        )
        if self.config["task"]["dataset"] == "single_point_lmdb":
            self.train_dataset = registry.get_dataset_class(
                self.config["task"]["dataset"]
            )(self.config["dataset"])

            self.train_loader = DataLoader(
                self.train_dataset,
                batch_size=self.config["optim"]["batch_size"],
                shuffle=True,
                collate_fn=self.parallel_collater,
                num_workers=self.config["optim"]["num_workers"],
                pin_memory=True,
            )

            self.val_loader = self.test_loader = None

            if "val_dataset" in self.config:
                self.val_dataset = registry.get_dataset_class(
                    self.config["task"]["dataset"]
                )(self.config["val_dataset"])
                self.val_loader = DataLoader(
                    self.val_dataset,
                    self.config["optim"].get("eval_batch_size", 64),
                    shuffle=False,
                    collate_fn=self.parallel_collater,
                    num_workers=self.config["optim"]["num_workers"],
                    pin_memory=True,
                )
        else:
            raise NotImplementedError

        self.num_targets = 1

        # Normalizer for the dataset.
        # Compute mean, std of training set labels.
        self.normalizers = {}
        if self.config["dataset"].get("normalize_labels", True):
            if "target_mean" in self.config["dataset"]:
                self.normalizers["target"] = Normalizer(
                    mean=self.config["dataset"]["target_mean"],
                    std=self.config["dataset"]["target_std"],
                    device=self.device,
                )
            else:
                raise NotImplementedError

    def load_model(self):
        super(EnergyTrainer, self).load_model()

        self.model = OCPDataParallel(
            self.model,
            output_device=self.output_device,
            num_gpus=self.config["optim"].get("num_gpus", 1),
        )
        self.model.to(self.device)

    def train(self):
        self.best_val_mae = 1e9

        for epoch in range(self.config["optim"]["max_epochs"]):
            self.model.train()

            for i, batch in enumerate(self.train_loader):
                # Forward, loss, backward.
                with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                    out = self._forward(batch)
                    loss = self._compute_loss(out, batch)
                loss = self.scaler.scale(loss) if self.scaler else loss
                self._backward(loss)
                scale = self.scaler.get_scale() if self.scaler else 1.0

                # Compute metrics.
                self.metrics = self._compute_metrics(
                    out,
                    batch,
                    self.evaluator,
                    metrics={},
                )
                self.metrics = self.evaluator.update(
                    "loss", loss.item() / scale, self.metrics
                )

                # Print metrics, make plots.
                log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
                log_dict.update(
                    {"epoch": epoch + (i + 1) / len(self.train_loader)}
                )
                if i % self.config["cmd"]["print_every"] == 0:
                    log_str = [
                        "{}: {:.4f}".format(k, v) for k, v in log_dict.items()
                    ]
                    print(", ".join(log_str))

                if self.logger is not None:
                    self.logger.log(
                        log_dict,
                        step=epoch * len(self.train_loader) + i + 1,
                        split="train",
                    )

            self.scheduler.step()
            torch.cuda.empty_cache()

            if self.val_loader is not None:
                val_metrics = self.validate(split="val", epoch=epoch)
                if (
                    val_metrics[
                        self.evaluator.task_primary_metric["is2re"]
                    ]["metric"]
                    < self.best_val_mae
                ):
                    self.best_val_mae = val_metrics[
                        self.evaluator.task_primary_metric["is2re"]
                    ]["metric"]
                    if not self.is_debug:
                        save_checkpoint(
                            {
                                "epoch": epoch + 1,
                                "state_dict": self.model.state_dict(),
                                "optimizer": self.optimizer.state_dict(),
                                "normalizers": {
                                    key: value.state_dict()
                                    for key, value in self.normalizers.items()
                                },
                                "config": self.config,
                                "val_metrics": val_metrics,
                                "amp": self.scaler.state_dict()
                                if self.scaler
                                else None,
                            },
                            self.config["cmd"]["checkpoint_dir"],
                        )

            if self.test_loader is not None:
                self.validate(split="test", epoch=epoch)

    def validate(self, split="val", epoch=None):
        print("### Evaluating on {}.".format(split))
        self.model.eval()

        evaluator, metrics = Evaluator(task="is2re"), {}
        loader = self.val_loader if split == "val" else self.test_loader

        for i, batch in tqdm(enumerate(loader), total=len(loader)):
            # Forward.
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                out = self._forward(batch)
            loss = self._compute_loss(out, batch)

            # Compute metrics.
            metrics = self._compute_metrics(out, batch, evaluator, metrics)
            metrics = evaluator.update("loss", loss.item(), metrics)

        log_dict = {k: metrics[k]["metric"] for k in metrics}
        log_dict.update({"epoch": epoch + 1})
        log_str = ["{}: {:.4f}".format(k, v) for k, v in log_dict.items()]
        print(", ".join(log_str))

        # Make plots.
        if self.logger is not None and epoch is not None:
            self.logger.log(
                log_dict,
                step=(epoch + 1) * len(self.train_loader),
                split=split,
            )

        return metrics

    def _forward(self, batch_list):
        output = self.model(batch_list)

        if output.shape[-1] == 1:
            output = output.view(-1)

        return {
            "energy": output,
        }

    def _compute_loss(self, out, batch_list):
        energy_target = torch.cat(
            [batch.y_relaxed.to(self.device) for batch in batch_list], dim=0
        )

        if self.config["dataset"].get("normalize_labels", True):
            target_normed = self.normalizers["target"].norm(energy_target)
        else:
            target_normed = energy_target

        loss = self.criterion(out["energy"], target_normed)
        return loss

    def _compute_metrics(self, out, batch_list, evaluator, metrics={}):
        energy_target = torch.cat(
            [batch.y_relaxed.to(self.device) for batch in batch_list], dim=0
        )

        if self.config["dataset"].get("normalize_labels", True):
            out["energy"] = self.normalizers["target"].denorm(out["energy"])

        metrics = evaluator.eval(
            out,
            {"energy": energy_target},
            prev_metrics=metrics,
        )
        return metrics
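EnergyTrainer.train() above follows the standard torch.cuda.amp recipe: autocast around the forward pass and loss, GradScaler around the backward pass. The sketch below spells out the full pattern for reference; it assumes _backward() ultimately calls scaler.step()/scaler.update(), which is not shown in the excerpt, and it needs a CUDA device to run.

import torch
import torch.nn.functional as F

model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(16, 8, device="cuda")
y = torch.randn(16, 1, device="cuda")

for _ in range(3):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = F.mse_loss(model(x), y)
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales gradients; skips the step on inf/nan
    scaler.update()                # adjusts the scale factor for the next iteration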
class ForcesTrainer(BaseTrainer):
    def __init__(
        self,
        task,
        model,
        dataset,
        optimizer,
        identifier,
        run_dir=None,
        is_debug=False,
        is_vis=False,
        print_every=100,
        seed=None,
        logger="tensorboard",
        local_rank=0,
        amp=False,
    ):
        if run_dir is None:
            run_dir = os.getcwd()
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        if identifier:
            timestamp += "-{}".format(identifier)

        self.config = {
            "task": task,
            "model": model.pop("name"),
            "model_attributes": model,
            "optim": optimizer,
            "logger": logger,
            "cmd": {
                "identifier": identifier,
                "print_every": print_every,
                "seed": seed,
                "timestamp": timestamp,
                "checkpoint_dir": os.path.join(run_dir, "checkpoints", timestamp),
                "results_dir": os.path.join(run_dir, "results", timestamp),
                "logs_dir": os.path.join(run_dir, "logs", logger, timestamp),
            },
        }

        if isinstance(dataset, list):
            self.config["dataset"] = dataset[0]
            if len(dataset) > 1:
                self.config["val_dataset"] = dataset[1]
        else:
            self.config["dataset"] = dataset

        if not is_debug:
            os.makedirs(self.config["cmd"]["checkpoint_dir"])
            os.makedirs(self.config["cmd"]["results_dir"])
            os.makedirs(self.config["cmd"]["logs_dir"])

        self.is_debug = is_debug
        self.is_vis = is_vis

        if torch.cuda.is_available():
            device_ids = list(range(torch.cuda.device_count()))
            self.output_device = self.config["optim"].get(
                "output_device", device_ids[0]
            )
            self.device = f"cuda:{self.output_device}"
        else:
            self.device = "cpu"

        print(yaml.dump(self.config, default_flow_style=False))
        self.load()

        self.scaler = None
        self.evaluator = Evaluator(task="s2ef")

    def load_task(self):
        print("### Loading dataset: {}".format(self.config["task"]["dataset"]))

        self.parallel_collater = ParallelCollater(
            self.config["optim"].get("num_gpus", 1)
        )
        if self.config["task"]["dataset"] == "trajectory_lmdb":
            self.train_dataset = registry.get_dataset_class(
                self.config["task"]["dataset"]
            )(self.config["dataset"])

            self.train_loader = DataLoader(
                self.train_dataset,
                batch_size=self.config["optim"]["batch_size"],
                shuffle=True,
                collate_fn=self.parallel_collater,
                num_workers=self.config["optim"]["num_workers"],
                pin_memory=True,
            )

            self.val_loader = self.test_loader = None

            if "val_dataset" in self.config:
                self.val_dataset = registry.get_dataset_class(
                    self.config["task"]["dataset"]
                )(self.config["val_dataset"])
                self.val_loader = DataLoader(
                    self.val_dataset,
                    self.config["optim"].get("eval_batch_size", 64),
                    shuffle=False,
                    collate_fn=self.parallel_collater,
                    num_workers=self.config["optim"]["num_workers"],
                    pin_memory=True,
                )
        else:
            self.dataset = registry.get_dataset_class(
                self.config["task"]["dataset"]
            )(self.config["dataset"])
            (
                self.train_loader,
                self.val_loader,
                self.test_loader,
            ) = self.dataset.get_dataloaders(
                batch_size=self.config["optim"]["batch_size"],
                collate_fn=self.parallel_collater,
            )

        self.num_targets = 1

        # Normalizer for the dataset.
        # Compute mean, std of training set labels.
        self.normalizers = {}
        if self.config["dataset"].get("normalize_labels", True):
            if "target_mean" in self.config["dataset"]:
                self.normalizers["target"] = Normalizer(
                    mean=self.config["dataset"]["target_mean"],
                    std=self.config["dataset"]["target_std"],
                    device=self.device,
                )
            else:
                self.normalizers["target"] = Normalizer(
                    tensor=self.train_loader.dataset.data.y[
                        self.train_loader.dataset.__indices__
                    ],
                    device=self.device,
                )

        # If we're computing gradients wrt input, set mean of normalizer to 0 --
        # since it is lost when computing dy / dx -- and std to forward target std.
        if self.config["model_attributes"].get("regress_forces", True):
            if self.config["dataset"].get("normalize_labels", True):
                if "grad_target_mean" in self.config["dataset"]:
                    self.normalizers["grad_target"] = Normalizer(
                        mean=self.config["dataset"]["grad_target_mean"],
                        std=self.config["dataset"]["grad_target_std"],
                        device=self.device,
                    )
                else:
                    self.normalizers["grad_target"] = Normalizer(
                        tensor=self.train_loader.dataset.data.y[
                            self.train_loader.dataset.__indices__
                        ],
                        device=self.device,
                    )
                    self.normalizers["grad_target"].mean.fill_(0)

        if self.is_vis and self.config["task"]["dataset"] != "qm9":
            # Plot label distribution.
            plots = [
                plot_histogram(
                    self.train_loader.dataset.data.y.tolist(),
                    xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                    ylabel="# Examples",
                    title="Split: train",
                ),
                plot_histogram(
                    self.val_loader.dataset.data.y.tolist(),
                    xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                    ylabel="# Examples",
                    title="Split: val",
                ),
                plot_histogram(
                    self.test_loader.dataset.data.y.tolist(),
                    xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                    ylabel="# Examples",
                    title="Split: test",
                ),
            ]
            self.logger.log_plots(plots)

    def load_model(self):
        super(ForcesTrainer, self).load_model()

        self.model = OCPDataParallel(
            self.model,
            output_device=self.output_device,
            num_gpus=self.config["optim"].get("num_gpus", 1),
        )
        self.model.to(self.device)

    # Takes in a new data source and generates predictions on it.
    def predict(self, dataset, batch_size=32):
        if isinstance(dataset, dict):
            if self.config["task"]["dataset"] == "trajectory_lmdb":
                print("### Generating predictions on {}.".format(dataset["src"]))
            else:
                print(
                    "### Generating predictions on {}.".format(
                        dataset["src"] + dataset["traj"]
                    )
                )

            dataset = registry.get_dataset_class(
                self.config["task"]["dataset"]
            )(dataset)
            data_loader = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=False,
                collate_fn=self.parallel_collater,
            )
        elif isinstance(dataset, torch_geometric.data.Batch):
            data_loader = [[dataset]]
        else:
            raise NotImplementedError

        self.model.eval()
        predictions = {"energy": [], "forces": []}

        for i, batch_list in enumerate(data_loader):
            out = self._forward(batch_list)
            if self.normalizers is not None and "target" in self.normalizers:
                out["energy"] = self.normalizers["target"].denorm(out["energy"])
                out["forces"] = self.normalizers["grad_target"].denorm(
                    out["forces"]
                )
            atoms_sum = 0
            predictions["energy"].extend(out["energy"].tolist())
            batch_natoms = torch.cat([batch.natoms for batch in batch_list])
            for natoms in batch_natoms:
                predictions["forces"].append(
                    out["forces"][atoms_sum : natoms + atoms_sum]
                    .cpu()
                    .detach()
                    .numpy()
                )
                atoms_sum += natoms

        return predictions

    def train(self):
        self.best_val_mae = 1e9
        eval_every = self.config["optim"].get("eval_every", -1)
        iters = 0
        self.metrics = {}

        for epoch in range(self.config["optim"]["max_epochs"]):
            self.model.train()

            for i, batch in enumerate(self.train_loader):
                # Forward, loss, backward.
                out = self._forward(batch)
                loss = self._compute_loss(out, batch)
                self._backward(loss)

                # Compute metrics.
                self.metrics = self._compute_metrics(
                    out,
                    batch,
                    self.evaluator,
                    self.metrics,
                )
                self.metrics = self.evaluator.update(
                    "loss", loss.item(), self.metrics
                )

                # Print metrics, make plots.
                log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
                log_dict.update(
                    {"epoch": epoch + (i + 1) / len(self.train_loader)}
                )
                if i % self.config["cmd"]["print_every"] == 0:
                    log_str = [
                        "{}: {:.4f}".format(k, v) for k, v in log_dict.items()
                    ]
                    print(", ".join(log_str))
                    self.metrics = {}

                if self.logger is not None:
                    self.logger.log(
                        log_dict,
                        step=epoch * len(self.train_loader) + i + 1,
                        split="train",
                    )

                iters += 1

                # Evaluate on val set every `eval_every` iterations.
                if eval_every != -1 and iters % eval_every == 0:
                    if self.val_loader is not None:
                        val_metrics = self.validate(split="val", epoch=epoch)
                        if (
                            val_metrics[
                                self.evaluator.task_primary_metric["s2ef"]
                            ]["metric"]
                            < self.best_val_mae
                        ):
                            self.best_val_mae = val_metrics[
                                self.evaluator.task_primary_metric["s2ef"]
                            ]["metric"]
                            if not self.is_debug:
                                save_checkpoint(
                                    {
                                        "epoch": epoch
                                        + (i + 1) / len(self.train_loader),
                                        "state_dict": self.model.state_dict(),
                                        "optimizer": self.optimizer.state_dict(),
                                        "normalizers": {
                                            key: value.state_dict()
                                            for key, value in self.normalizers.items()
                                        },
                                        "config": self.config,
                                        "val_metrics": val_metrics,
                                    },
                                    self.config["cmd"]["checkpoint_dir"],
                                )

            self.scheduler.step()
            torch.cuda.empty_cache()

            if eval_every == -1:
                if self.val_loader is not None:
                    val_metrics = self.validate(split="val", epoch=epoch)
                    if (
                        val_metrics[
                            self.evaluator.task_primary_metric["s2ef"]
                        ]["metric"]
                        < self.best_val_mae
                    ):
                        self.best_val_mae = val_metrics[
                            self.evaluator.task_primary_metric["s2ef"]
                        ]["metric"]
                        if not self.is_debug:
                            save_checkpoint(
                                {
                                    "epoch": epoch + 1,
                                    "state_dict": self.model.state_dict(),
                                    "optimizer": self.optimizer.state_dict(),
                                    "normalizers": {
                                        key: value.state_dict()
                                        for key, value in self.normalizers.items()
                                    },
                                    "config": self.config,
                                    "val_metrics": val_metrics,
                                },
                                self.config["cmd"]["checkpoint_dir"],
                            )

            if self.test_loader is not None:
                self.validate(split="test", epoch=epoch)

            if (
                "relaxation_dir" in self.config["task"]
                and self.config["task"].get("ml_relax", "end") == "train"
            ):
                self.validate_relaxation(split="val", epoch=epoch)

        if (
            "relaxation_dir" in self.config["task"]
            and self.config["task"].get("ml_relax", "end") == "end"
        ):
            self.validate_relaxation(split="val", epoch=epoch)

    def validate(self, split="val", epoch=None):
        print("### Evaluating on {}.".format(split))
        self.model.eval()

        evaluator, metrics = Evaluator(task="s2ef"), {}
        loader = self.val_loader if split == "val" else self.test_loader

        for i, batch in tqdm(enumerate(loader), total=len(loader)):
            # Forward.
            out = self._forward(batch)
            loss = self._compute_loss(out, batch)

            # Compute metrics.
            metrics = self._compute_metrics(out, batch, evaluator, metrics)
            metrics = evaluator.update("loss", loss.item(), metrics)

        log_dict = {k: metrics[k]["metric"] for k in metrics}
        log_dict.update({"epoch": epoch + 1})
        log_str = ["{}: {:.4f}".format(k, v) for k, v in log_dict.items()]
        print(", ".join(log_str))

        # Make plots.
        if self.logger is not None and epoch is not None:
            self.logger.log(
                log_dict,
                step=(epoch + 1) * len(self.train_loader),
                split=split,
            )

        return metrics

    def validate_relaxation(self, split="val", epoch=None):
        print("### Evaluating ML-relaxation")
        self.model.eval()

        mae_energy, mae_structure = relax_eval(
            trainer=self,
            traj_dir=self.config["task"]["relaxation_dir"],
            metric=self.config["task"]["metric"],
            steps=self.config["task"].get("relaxation_steps", 300),
            fmax=self.config["task"].get("relaxation_fmax", 0.01),
            results_dir=self.config["cmd"]["results_dir"],
        )

        log_dict = {
            "relaxed_energy_mae": mae_energy,
            "relaxed_structure_mae": mae_structure,
            "epoch": epoch + 1,
        }

        # Make plots.
        if self.logger is not None and epoch is not None:
            self.logger.log(
                log_dict,
                step=(epoch + 1) * len(self.train_loader),
                split=split,
            )

        print(log_dict)
        return mae_energy, mae_structure

    def _forward(self, batch_list):
        if self.config["model_attributes"].get("regress_forces", True):
            out_energy, out_forces = self.model(batch_list)
        else:
            out_energy = self.model(batch_list)

        if out_energy.shape[-1] == 1:
            out_energy = out_energy.view(-1)

        out = {
            "energy": out_energy,
        }

        if self.config["model_attributes"].get("regress_forces", True):
            out["forces"] = out_forces

        return out

    def _compute_loss(self, out, batch_list):
        loss = []

        # Energy loss.
        energy_target = torch.cat(
            [batch.y.to(self.device) for batch in batch_list], dim=0
        )
        if self.config["dataset"].get("normalize_labels", True):
            energy_target = self.normalizers["target"].norm(energy_target)
        energy_mult = self.config["optim"].get("energy_coefficient", 1)
        loss.append(energy_mult * self.criterion(out["energy"], energy_target))

        # Forces loss.
        if self.config["model_attributes"].get("regress_forces", True):
            force_target = torch.cat(
                [batch.force.to(self.device) for batch in batch_list], dim=0
            )
            if self.config["dataset"].get("normalize_labels", True):
                force_target = self.normalizers["grad_target"].norm(force_target)

            # Force coefficient = 30 has been working well for us.
            force_mult = self.config["optim"].get("force_coefficient", 30)
            if self.config["task"].get("train_on_free_atoms", False):
                fixed = torch.cat(
                    [batch.fixed.to(self.device) for batch in batch_list]
                )
                mask = fixed == 0
                loss.append(
                    force_mult
                    * self.criterion(out["forces"][mask], force_target[mask])
                )
            else:
                loss.append(
                    force_mult * self.criterion(out["forces"], force_target)
                )

        # Sanity check to make sure the compute graph is correct.
        for lc in loss:
            assert hasattr(lc, "grad_fn")

        loss = sum(loss)
        return loss

    def _compute_metrics(self, out, batch_list, evaluator, metrics={}):
        target = {
            "energy": torch.cat(
                [batch.y.to(self.device) for batch in batch_list], dim=0
            ),
            "forces": torch.cat(
                [batch.force.to(self.device) for batch in batch_list], dim=0
            ),
        }

        if self.config["task"].get("eval_on_free_atoms", True):
            fixed = torch.cat(
                [batch.fixed.to(self.device) for batch in batch_list]
            )
            mask = fixed == 0
            out["forces"] = out["forces"][mask]
            target["forces"] = target["forces"][mask]

        if self.config["dataset"].get("normalize_labels", True):
            out["energy"] = self.normalizers["target"].denorm(out["energy"])
            out["forces"] = self.normalizers["grad_target"].denorm(out["forces"])

        metrics = evaluator.eval(out, target, prev_metrics=metrics)
        return metrics
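Both trainers lean on a Normalizer object with norm()/denorm() methods and a state_dict() that gets saved into checkpoints. A minimal sketch of the interface they assume, not the real ocpmodels implementation:

import torch


class Normalizer:
    # Affine (mean/std) normalization over a label tensor.
    def __init__(self, tensor=None, mean=None, std=None, device="cpu"):
        if tensor is not None:
            # Statistics computed from the training labels.
            self.mean = torch.mean(tensor).to(device)
            self.std = torch.std(tensor).to(device)
        else:
            # Statistics supplied directly via the dataset config.
            self.mean = torch.tensor(mean, device=device)
            self.std = torch.tensor(std, device=device)

    def norm(self, tensor):
        return (tensor - self.mean) / self.std

    def denorm(self, normed_tensor):
        return normed_tensor * self.std + self.mean

    def state_dict(self):
        return {"mean": self.mean, "std": self.std}

    def load_state_dict(self, state_dict):
        self.mean = state_dict["mean"]
        self.std = state_dict["std"]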