def load_task(self):
    print("### Loading dataset: {}".format(self.config["task"]["dataset"]))

    self.parallel_collater = ParallelCollater(1)
    if self.config["task"]["dataset"] == "single_point_lmdb":
        self.train_dataset = registry.get_dataset_class(
            self.config["task"]["dataset"])(self.config["dataset"])

        self.train_sampler = DistributedSampler(
            self.train_dataset,
            num_replicas=distutils.get_world_size(),
            rank=distutils.get_rank(),
            shuffle=True,
        )
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config["optim"]["batch_size"],
            collate_fn=self.parallel_collater,
            num_workers=self.config["optim"]["num_workers"],
            pin_memory=True,
            sampler=self.train_sampler,
        )

        self.val_loader = self.test_loader = None
        self.val_sampler = None

        if "val_dataset" in self.config:
            self.val_dataset = registry.get_dataset_class(
                self.config["task"]["dataset"])(self.config["val_dataset"])
            self.val_sampler = DistributedSampler(
                self.val_dataset,
                num_replicas=distutils.get_world_size(),
                rank=distutils.get_rank(),
                shuffle=False,
            )
            self.val_loader = DataLoader(
                self.val_dataset,
                self.config["optim"].get("eval_batch_size", 64),
                collate_fn=self.parallel_collater,
                num_workers=self.config["optim"]["num_workers"],
                pin_memory=True,
                sampler=self.val_sampler,
            )
    else:
        raise NotImplementedError

    self.num_targets = 1

    # Normalizer for the dataset.
    # Compute mean, std of training set labels.
    self.normalizers = {}
    if self.config["dataset"].get("normalize_labels", True):
        if "target_mean" in self.config["dataset"]:
            self.normalizers["target"] = Normalizer(
                mean=self.config["dataset"]["target_mean"],
                std=self.config["dataset"]["target_std"],
                device=self.device,
            )
        else:
            raise NotImplementedError

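# Illustrative sketch (not part of the trainer): a minimal config dict that
# would exercise the "single_point_lmdb" branch of `load_task` above. The key
# names are taken from the accesses in the method; the paths and numeric
# values are hypothetical placeholders.
_example_config = {
    "task": {"dataset": "single_point_lmdb"},
    "dataset": {
        "src": "data/is2re/train/data.lmdb",  # hypothetical path
        "normalize_labels": True,
        "target_mean": -0.7554,               # placeholder label statistics
        "target_std": 2.8871,
    },
    "val_dataset": {"src": "data/is2re/val/data.lmdb"},  # hypothetical path
    "optim": {
        "batch_size": 32,
        "eval_batch_size": 64,
        "num_workers": 4,
    },
}
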
def __init__(self, config, transform=None):
    super(TrajectoryLmdbDataset, self).__init__()

    self.config = config

    # If running in distributed mode, only read a subset of database files.
    world_size = distutils.get_world_size()
    rank = distutils.get_rank()

    srcdir = Path(self.config["src"])
    db_paths = sorted(srcdir.glob("*.lmdb"))
    assert len(db_paths) > 0, f"No LMDBs found in {srcdir}"

    # Each process only reads a subset of the DB files. However, since the
    # number of DB files may not be divisible by world size, the final
    # (num_dbs % world_size) are shared by all processes.
    num_full_dbs = len(db_paths) - (len(db_paths) % world_size)
    full_db_paths = db_paths[rank:num_full_dbs:world_size]
    shared_db_paths = db_paths[num_full_dbs:]
    self.db_paths = full_db_paths + shared_db_paths

    self._keys, self.envs = [], []
    for db_path in full_db_paths:
        self.envs.append(self.connect_db(db_path))
        length = pickle.loads(self.envs[-1].begin().get(
            "length".encode("ascii")))
        self._keys.append(list(range(length)))

    for db_path in shared_db_paths:
        self.envs.append(self.connect_db(db_path))
        length = pickle.loads(self.envs[-1].begin().get(
            "length".encode("ascii")))
        length -= length % world_size
        self._keys.append(list(range(rank, length, world_size)))

    self._keylens = [len(k) for k in self._keys]
    self._keylen_cumulative = np.cumsum(self._keylens).tolist()
    self.num_samples = sum(self._keylens)

    self.transform = transform

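# Illustrative sketch (not part of the dataset class): how the strided slice
# above deals DB files out across ranks. With 8 LMDB files and a hypothetical
# world size of 3, the last 8 % 3 = 2 files are shared by every rank and the
# first 6 are assigned round-robin. `_shard_db_paths` is a hypothetical helper
# that just repeats the slicing from `__init__`.
def _shard_db_paths(db_paths, rank, world_size):
    num_full_dbs = len(db_paths) - (len(db_paths) % world_size)
    full = db_paths[rank:num_full_dbs:world_size]
    shared = db_paths[num_full_dbs:]
    return full, shared

_paths = [f"db_{n}.lmdb" for n in range(8)]
assert _shard_db_paths(_paths, rank=0, world_size=3) == (
    ["db_0.lmdb", "db_3.lmdb"], ["db_6.lmdb", "db_7.lmdb"])
assert _shard_db_paths(_paths, rank=1, world_size=3) == (
    ["db_1.lmdb", "db_4.lmdb"], ["db_6.lmdb", "db_7.lmdb"])
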
def predict(self, loader, results_file=None, disable_tqdm=False):
    if distutils.is_master() and not disable_tqdm:
        print("### Predicting on test.")

    assert isinstance(loader, torch.utils.data.dataloader.DataLoader)

    rank = distutils.get_rank()

    self.model.eval()
    if self.normalizers is not None and "target" in self.normalizers:
        self.normalizers["target"].to(self.device)
    predictions = {"id": [], "energy": []}

    for i, batch in tqdm(
        enumerate(loader),
        total=len(loader),
        position=rank,
        desc="device {}".format(rank),
        disable=disable_tqdm,
    ):
        with torch.cuda.amp.autocast(enabled=self.scaler is not None):
            out = self._forward(batch)

        if self.normalizers is not None and "target" in self.normalizers:
            out["energy"] = self.normalizers["target"].denorm(
                out["energy"])
        predictions["id"].extend(
            [str(sid) for sid in batch[0].sid.tolist()])
        predictions["energy"].extend(out["energy"].tolist())

    self.save_results(predictions, results_file, keys=["energy"])
    return predictions

def validate(self, split="val", epoch=None, disable_tqdm=False):
    if distutils.is_master():
        print("### Evaluating on {}.".format(split))
    if self.is_hpo:
        disable_tqdm = True

    self.model.eval()
    evaluator, metrics = Evaluator(task=self.name), {}
    rank = distutils.get_rank()

    loader = self.val_loader if split == "val" else self.test_loader

    for i, batch in tqdm(
        enumerate(loader),
        total=len(loader),
        position=rank,
        desc="device {}".format(rank),
        disable=disable_tqdm,
    ):
        # Forward.
        with torch.cuda.amp.autocast(enabled=self.scaler is not None):
            out = self._forward(batch)
        loss = self._compute_loss(out, batch)

        # Compute metrics.
        metrics = self._compute_metrics(out, batch, evaluator, metrics)
        metrics = evaluator.update("loss", loss.item(), metrics)

    aggregated_metrics = {}
    for k in metrics:
        aggregated_metrics[k] = {
            "total": distutils.all_reduce(metrics[k]["total"],
                                          average=False,
                                          device=self.device),
            "numel": distutils.all_reduce(metrics[k]["numel"],
                                          average=False,
                                          device=self.device),
        }
        aggregated_metrics[k]["metric"] = (aggregated_metrics[k]["total"] /
                                           aggregated_metrics[k]["numel"])
    metrics = aggregated_metrics

    log_dict = {k: metrics[k]["metric"] for k in metrics}
    # `epoch` may be None (it defaults to None), so only log it when given.
    if epoch is not None:
        log_dict.update({"epoch": epoch + 1})
    if distutils.is_master():
        log_str = ["{}: {:.4f}".format(k, v) for k, v in log_dict.items()]
        print(", ".join(log_str))

    # Make plots.
    if self.logger is not None and epoch is not None:
        self.logger.log(
            log_dict,
            step=(epoch + 1) * len(self.train_loader),
            split=split,
        )

    return metrics

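# Illustrative arithmetic (not part of the trainer): reducing "total" and
# "numel" separately, as `validate` does above, gives the correct global mean
# even when ranks evaluate different numbers of samples. Hypothetical
# per-rank sums:
_rank_totals, _rank_numels = [10.0, 2.0], [4, 1]
_naive_mean_of_means = sum(
    t / n for t, n in zip(_rank_totals, _rank_numels)) / 2
_reduced_metric = sum(_rank_totals) / sum(_rank_numels)
assert _naive_mean_of_means == 2.25        # biased toward the smaller rank
assert abs(_reduced_metric - 2.4) < 1e-12  # true mean over all 5 samples
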
def predict(self,
            loader,
            per_image=True,
            results_file=None,
            disable_tqdm=False):
    if distutils.is_master() and not disable_tqdm:
        logging.info("Predicting on test.")
    assert isinstance(
        loader,
        (
            torch.utils.data.dataloader.DataLoader,
            torch_geometric.data.Batch,
        ),
    )
    rank = distutils.get_rank()
    if isinstance(loader, torch_geometric.data.Batch):
        loader = [[loader]]

    self.model.eval()
    if self.ema:
        self.ema.store()
        self.ema.copy_to()

    if self.normalizers is not None and "target" in self.normalizers:
        self.normalizers["target"].to(self.device)
    predictions = {"id": [], "energy": []}

    for i, batch in tqdm(
        enumerate(loader),
        total=len(loader),
        position=rank,
        desc="device {}".format(rank),
        disable=disable_tqdm,
    ):
        with torch.cuda.amp.autocast(enabled=self.scaler is not None):
            out = self._forward(batch)

        if self.normalizers is not None and "target" in self.normalizers:
            out["energy"] = self.normalizers["target"].denorm(
                out["energy"])

        if per_image:
            predictions["id"].extend(
                [str(sid) for sid in batch[0].sid.tolist()])
            predictions["energy"].extend(out["energy"].tolist())
        else:
            predictions["energy"] = out["energy"].detach()
            # Restore the online weights before the early return so the
            # EMA weights are not left active on the model.
            if self.ema:
                self.ema.restore()
            return predictions

    self.save_results(predictions, results_file, keys=["energy"])

    if self.ema:
        self.ema.restore()

    return predictions

def __init__(self, config, transform=None):
    super(TrajectoryLmdbDataset, self).__init__()

    self.config = config

    world_size = distutils.get_world_size()
    rank = distutils.get_rank()

    srcdir = Path(self.config["src"])
    db_paths = sorted(srcdir.glob("*.lmdb"))
    assert len(db_paths) > 0, f"No LMDBs found in {srcdir}"

    # Read all LMDBs to set the size of each dataloader replica.
    lengths = []
    for db_path in db_paths:
        env = self.connect_db(db_path)
        lengths.append(
            pickle.loads(env.begin().get("length".encode("ascii"))))
        env.close()
    lengths.sort(reverse=True)
    replica_size = sum(lengths[:math.ceil(len(lengths) / world_size)])

    # Each process only reads a subset of the DB files. However, since the
    # number of DB files may not be divisible by world size, the final
    # (num_dbs % world_size) are shared by all processes.
    num_full_dbs = len(db_paths) - (len(db_paths) % world_size)
    full_db_paths = db_paths[rank:num_full_dbs:world_size]
    shared_db_paths = db_paths[num_full_dbs:]
    self.db_paths = full_db_paths + shared_db_paths

    self._keys, self.envs = [], []
    for db_path in full_db_paths:
        self.envs.append(self.connect_db(db_path))
        length = pickle.loads(self.envs[-1].begin().get(
            "length".encode("ascii")))
        self._keys.append(list(range(length)))

    for db_path in shared_db_paths:
        self.envs.append(self.connect_db(db_path))
        length = pickle.loads(self.envs[-1].begin().get(
            "length".encode("ascii")))
        length -= length % world_size
        self._keys.append(list(range(rank, length, world_size)))

    keylens = [len(k) for k in self._keys]

    # Need to pad dataloaders so all have the same no. of samples.
    # This means that dataloaders will have some repeated samples
    # that need to be pruned out in post-processing.
    if sum(keylens) < replica_size:
        self._keys[-1].extend([self._keys[-1][-1]] *
                              (replica_size - sum(keylens)))
        keylens = [len(k) for k in self._keys]

    self._keylen_cumulative = np.cumsum(keylens).tolist()
    self.transform = transform

    self.num_samples = sum(keylens)
    assert self.num_samples == replica_size

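# Illustrative sketch (not part of the dataset class): the padding above
# sizes every replica to `replica_size`, computed from the sorted DB lengths,
# by repeating the last key. The repeated samples are later pruned in
# post-processing, as the original comment notes. Hypothetical numbers:
import math

_lengths = sorted([100, 80, 60, 50], reverse=True)  # per-DB sample counts
_world_size = 3
_replica_size = sum(_lengths[:math.ceil(len(_lengths) / _world_size)])
assert _replica_size == 180  # sum of the ceil(4/3) = 2 largest DB lengths

# A rank whose assigned DBs hold only 110 samples gets 70 repeated keys
# appended so its dataloader matches the other replicas.
_keys = list(range(110))
_keys.extend([_keys[-1]] * (_replica_size - len(_keys)))
assert len(_keys) == _replica_size
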
def get_sampler(self, dataset, batch_size, shuffle):
    if "load_balancing" in self.config["optim"]:
        balancing_mode = self.config["optim"]["load_balancing"]
        force_balancing = True
    else:
        balancing_mode = "atoms"
        force_balancing = False

    sampler = BalancedBatchSampler(
        dataset,
        batch_size=batch_size,
        num_replicas=distutils.get_world_size(),
        rank=distutils.get_rank(),
        device=self.device,
        mode=balancing_mode,
        shuffle=shuffle,
        force_balancing=force_balancing,
    )
    return sampler

def load_task(self):
    print("### Loading dataset: {}".format(self.config["task"]["dataset"]))

    self.parallel_collater = ParallelCollater(
        1 if not self.cpu else 0,
        self.config["model_attributes"].get("otf_graph", False),
    )
    if self.config["task"]["dataset"] == "trajectory_lmdb":
        self.train_dataset = registry.get_dataset_class(
            self.config["task"]["dataset"])(self.config["dataset"])

        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config["optim"]["batch_size"],
            shuffle=True,
            collate_fn=self.parallel_collater,
            num_workers=self.config["optim"]["num_workers"],
            pin_memory=True,
        )

        self.val_loader = self.test_loader = None

        if "val_dataset" in self.config:
            self.val_dataset = registry.get_dataset_class(
                self.config["task"]["dataset"])(self.config["val_dataset"])
            self.val_loader = DataLoader(
                self.val_dataset,
                self.config["optim"].get("eval_batch_size", 64),
                shuffle=False,
                collate_fn=self.parallel_collater,
                num_workers=self.config["optim"]["num_workers"],
                pin_memory=True,
            )
        if "test_dataset" in self.config:
            self.test_dataset = registry.get_dataset_class(
                self.config["task"]["dataset"])(
                    self.config["test_dataset"])
            self.test_loader = DataLoader(
                self.test_dataset,
                self.config["optim"].get("eval_batch_size", 64),
                shuffle=False,
                collate_fn=self.parallel_collater,
                num_workers=self.config["optim"]["num_workers"],
                pin_memory=True,
            )

        if "relax_dataset" in self.config["task"]:
            assert os.path.isfile(
                self.config["task"]["relax_dataset"]["src"])

            self.relax_dataset = registry.get_dataset_class(
                "single_point_lmdb")(self.config["task"]["relax_dataset"])

            self.relax_sampler = DistributedSampler(
                self.relax_dataset,
                num_replicas=distutils.get_world_size(),
                rank=distutils.get_rank(),
                shuffle=False,
            )
            self.relax_loader = DataLoader(
                self.relax_dataset,
                batch_size=self.config["optim"].get("eval_batch_size", 64),
                collate_fn=self.parallel_collater,
                num_workers=self.config["optim"]["num_workers"],
                pin_memory=True,
                sampler=self.relax_sampler,
            )
    else:
        self.dataset = registry.get_dataset_class(
            self.config["task"]["dataset"])(self.config["dataset"])
        (
            self.train_loader,
            self.val_loader,
            self.test_loader,
        ) = self.dataset.get_dataloaders(
            batch_size=self.config["optim"]["batch_size"],
            collate_fn=self.parallel_collater,
        )

    self.num_targets = 1

    # Normalizer for the dataset.
    # Compute mean, std of training set labels.
    self.normalizers = {}
    if self.config["dataset"].get("normalize_labels", False):
        if "target_mean" in self.config["dataset"]:
            self.normalizers["target"] = Normalizer(
                mean=self.config["dataset"]["target_mean"],
                std=self.config["dataset"]["target_std"],
                device=self.device,
            )
        else:
            self.normalizers["target"] = Normalizer(
                tensor=self.train_loader.dataset.data.y[
                    self.train_loader.dataset.__indices__],
                device=self.device,
            )

    # If we're computing gradients wrt input, set the normalizer mean to 0
    # -- it is lost when computing dy/dx -- and the std to the forward
    # target std.
    if self.config["model_attributes"].get("regress_forces", True):
        if self.config["dataset"].get("normalize_labels", False):
            if "grad_target_mean" in self.config["dataset"]:
                self.normalizers["grad_target"] = Normalizer(
                    mean=self.config["dataset"]["grad_target_mean"],
                    std=self.config["dataset"]["grad_target_std"],
                    device=self.device,
                )
            else:
                self.normalizers["grad_target"] = Normalizer(
                    tensor=self.train_loader.dataset.data.y[
                        self.train_loader.dataset.__indices__],
                    device=self.device,
                )
                self.normalizers["grad_target"].mean.fill_(0)

    if (self.is_vis and self.config["task"]["dataset"] != "qm9"
            and distutils.is_master()):
        # Plot label distribution.
        plots = [
            plot_histogram(
                self.train_loader.dataset.data.y.tolist(),
                xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                ylabel="# Examples",
                title="Split: train",
            ),
            plot_histogram(
                self.val_loader.dataset.data.y.tolist(),
                xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                ylabel="# Examples",
                title="Split: val",
            ),
            plot_histogram(
                self.test_loader.dataset.data.y.tolist(),
                xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                ylabel="# Examples",
                title="Split: test",
            ),
        ]
        self.logger.log_plots(plots)

def run_relaxations(self, split="val", epoch=None):
    print("### Running ML-relaxations")
    self.model.eval()

    evaluator, metrics = Evaluator(task="is2rs"), {}

    if hasattr(self.relax_dataset[0], "pos_relaxed") and hasattr(
            self.relax_dataset[0], "y_relaxed"):
        split = "val"
    else:
        split = "test"

    ids = []
    relaxed_positions = []
    for i, batch in tqdm(enumerate(self.relax_loader),
                         total=len(self.relax_loader)):
        relaxed_batch = ml_relax(
            batch=batch,
            model=self,
            steps=self.config["task"].get("relaxation_steps", 200),
            fmax=self.config["task"].get("relaxation_fmax", 0.0),
            relax_opt=self.config["task"]["relax_opt"],
            device=self.device,
            transform=None,
        )

        if self.config["task"].get("write_pos", False):
            systemids = [str(sid) for sid in relaxed_batch.sid.tolist()]
            natoms = relaxed_batch.natoms.tolist()
            positions = torch.split(relaxed_batch.pos, natoms)
            batch_relaxed_positions = [pos.tolist() for pos in positions]

            relaxed_positions += batch_relaxed_positions
            ids += systemids

        if split == "val":
            mask = relaxed_batch.fixed == 0
            s_idx = 0
            natoms_free = []
            for natoms in relaxed_batch.natoms:
                natoms_free.append(
                    torch.sum(mask[s_idx:s_idx + natoms]).item())
                s_idx += natoms

            target = {
                "energy": relaxed_batch.y_relaxed,
                "positions": relaxed_batch.pos_relaxed[mask],
                "cell": relaxed_batch.cell,
                "pbc": torch.tensor([True, True, True]),
                "natoms": torch.LongTensor(natoms_free),
            }
            prediction = {
                "energy": relaxed_batch.y,
                "positions": relaxed_batch.pos[mask],
                "cell": relaxed_batch.cell,
                "pbc": torch.tensor([True, True, True]),
                "natoms": torch.LongTensor(natoms_free),
            }
            metrics = evaluator.eval(prediction, target, metrics)

    if self.config["task"].get("write_pos", False):
        rank = distutils.get_rank()
        pos_filename = os.path.join(self.config["cmd"]["results_dir"],
                                    f"relaxed_pos_{rank}.npz")
        np.savez_compressed(
            pos_filename,
            ids=ids,
            pos=np.array(relaxed_positions, dtype=object),
        )

        distutils.synchronize()
        if distutils.is_master():
            gather_results = defaultdict(list)
            full_path = os.path.join(
                self.config["cmd"]["results_dir"],
                "relaxed_positions.npz",
            )

            for i in range(distutils.get_world_size()):
                rank_path = os.path.join(
                    self.config["cmd"]["results_dir"],
                    f"relaxed_pos_{i}.npz",
                )
                rank_results = np.load(rank_path, allow_pickle=True)
                gather_results["ids"].extend(rank_results["ids"])
                gather_results["pos"].extend(rank_results["pos"])
                os.remove(rank_path)

            # Because of how distributed sampler works, some system ids
            # might be repeated to make no. of samples even across GPUs.
            _, idx = np.unique(gather_results["ids"], return_index=True)
            gather_results["ids"] = np.array(gather_results["ids"])[idx]
            gather_results["pos"] = np.array(gather_results["pos"],
                                             dtype=object)[idx]

            print(f"Writing results to {full_path}")
            np.savez_compressed(full_path, **gather_results)

    if split == "val":
        aggregated_metrics = {}
        for k in metrics:
            aggregated_metrics[k] = {
                "total": distutils.all_reduce(metrics[k]["total"],
                                              average=False,
                                              device=self.device),
                "numel": distutils.all_reduce(metrics[k]["numel"],
                                              average=False,
                                              device=self.device),
            }
            aggregated_metrics[k]["metric"] = (
                aggregated_metrics[k]["total"] /
                aggregated_metrics[k]["numel"])
        metrics = aggregated_metrics

        # Make plots.
        log_dict = {k: metrics[k]["metric"] for k in metrics}
        if self.logger is not None and epoch is not None:
            self.logger.log(
                log_dict,
                step=(epoch + 1) * len(self.train_loader),
                split=split,
            )

        if distutils.is_master():
            print(metrics)

def predict(self,
            data_loader,
            per_image=True,
            results_file=None,
            disable_tqdm=True):
    if distutils.is_master() and not disable_tqdm:
        print("### Predicting on test.")

    assert isinstance(
        data_loader,
        (
            torch.utils.data.dataloader.DataLoader,
            torch_geometric.data.Batch,
        ),
    )
    rank = distutils.get_rank()
    if isinstance(data_loader, torch_geometric.data.Batch):
        data_loader = [[data_loader]]

    self.model.eval()
    if self.normalizers is not None and "target" in self.normalizers:
        self.normalizers["target"].to(self.device)
        self.normalizers["grad_target"].to(self.device)

    predictions = {"id": [], "energy": [], "forces": []}

    for i, batch_list in tqdm(
        enumerate(data_loader),
        total=len(data_loader),
        position=rank,
        desc="device {}".format(rank),
        disable=disable_tqdm,
    ):
        with torch.cuda.amp.autocast(enabled=self.scaler is not None):
            out = self._forward(batch_list)

        if self.normalizers is not None and "target" in self.normalizers:
            out["energy"] = self.normalizers["target"].denorm(
                out["energy"])
            out["forces"] = self.normalizers["grad_target"].denorm(
                out["forces"])
        if per_image:
            atoms_sum = 0
            systemids = [
                str(sid) + "_" + str(fid) for sid, fid in zip(
                    batch_list[0].sid.tolist(), batch_list[0].fid.tolist())
            ]
            predictions["id"].extend(systemids)
            predictions["energy"].extend(out["energy"].to(
                torch.float16).tolist())
            batch_natoms = torch.cat(
                [batch.natoms for batch in batch_list])
            batch_fixed = torch.cat([batch.fixed for batch in batch_list])
            for natoms in batch_natoms:
                forces = (out["forces"][atoms_sum:natoms +
                                        atoms_sum].cpu().detach().to(
                                            torch.float16).numpy())
                # evalAI only requires forces on free atoms.
                if results_file is not None:
                    _free_atoms = (batch_fixed[atoms_sum:natoms +
                                               atoms_sum] == 0).tolist()
                    forces = forces[_free_atoms]

                atoms_sum += natoms
                predictions["forces"].append(forces)
        else:
            predictions["energy"] = out["energy"].detach()
            predictions["forces"] = out["forces"].detach()
            return predictions

    predictions["forces"] = np.array(predictions["forces"], dtype=object)
    predictions["energy"] = np.array(predictions["energy"])
    predictions["id"] = np.array(predictions["id"])
    self.save_results(predictions, results_file, keys=["energy", "forces"])
    return predictions
