def load_task(self):
    """Load the single-point LMDB dataset and build distributed dataloaders.

    Creates a shuffled, DistributedSampler-backed train loader and, if a
    "val_dataset" entry exists in the config, an unshuffled val loader.
    Only the "single_point_lmdb" dataset type is supported here. Finally,
    sets up the target Normalizer from config-provided mean/std.

    Raises:
        NotImplementedError: for any other dataset type, or when
            normalization is enabled but "target_mean" is not in the
            dataset config.
    """
    print("### Loading dataset: {}".format(self.config["task"]["dataset"]))

    self.parallel_collater = ParallelCollater(1)
    if self.config["task"]["dataset"] == "single_point_lmdb":
        self.train_dataset = registry.get_dataset_class(
            self.config["task"]["dataset"])(self.config["dataset"])

        # Shuffled sampler so each DDP replica sees a distinct shard
        # per epoch.
        self.train_sampler = DistributedSampler(
            self.train_dataset,
            num_replicas=distutils.get_world_size(),
            rank=distutils.get_rank(),
            shuffle=True,
        )
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config["optim"]["batch_size"],
            collate_fn=self.parallel_collater,
            num_workers=self.config["optim"]["num_workers"],
            pin_memory=True,
            sampler=self.train_sampler,
        )

        self.val_loader = self.test_loader = None
        self.val_sampler = None

        if "val_dataset" in self.config:
            self.val_dataset = registry.get_dataset_class(
                self.config["task"]["dataset"])(self.config["val_dataset"])
            # Validation is not shuffled so eval order is deterministic.
            self.val_sampler = DistributedSampler(
                self.val_dataset,
                num_replicas=distutils.get_world_size(),
                rank=distutils.get_rank(),
                shuffle=False,
            )
            self.val_loader = DataLoader(
                self.val_dataset,
                self.config["optim"].get("eval_batch_size", 64),
                collate_fn=self.parallel_collater,
                num_workers=self.config["optim"]["num_workers"],
                pin_memory=True,
                sampler=self.val_sampler,
            )
    else:
        raise NotImplementedError

    self.num_targets = 1

    # Normalizer for the dataset.
    # Compute mean, std of training set labels.
    self.normalizers = {}
    if self.config["dataset"].get("normalize_labels", True):
        if "target_mean" in self.config["dataset"]:
            self.normalizers["target"] = Normalizer(
                mean=self.config["dataset"]["target_mean"],
                std=self.config["dataset"]["target_std"],
                device=self.device,
            )
        else:
            # On-the-fly mean/std computation is not supported for this
            # dataset type; stats must be supplied in the config.
            raise NotImplementedError
def __init__(self, config, transform=None):
    """Open a sharded set of trajectory LMDB files for this rank.

    Args:
        config (dict): must contain "src", a directory of *.lmdb files.
        transform (callable, optional): stored for later use on samples
            (presumably applied in __getitem__ — confirm against the
            rest of the class).
    """
    super(TrajectoryLmdbDataset, self).__init__()
    self.config = config

    # If running in distributed mode, only read a subset of database files
    world_size = distutils.get_world_size()
    rank = distutils.get_rank()

    srcdir = Path(self.config["src"])
    db_paths = sorted(srcdir.glob("*.lmdb"))
    assert len(db_paths) > 0, f"No LMDBs found in {srcdir}"

    # Each process only reads a subset of the DB files. However, since the
    # number of DB files may not be divisible by world size, the final
    # (num_dbs % world_size) are shared by all processes.
    num_full_dbs = len(db_paths) - (len(db_paths) % world_size)
    full_db_paths = db_paths[rank:num_full_dbs:world_size]
    shared_db_paths = db_paths[num_full_dbs:]
    self.db_paths = full_db_paths + shared_db_paths

    self._keys, self.envs = [], []
    # DBs owned exclusively by this rank: read every key.
    for db_path in full_db_paths:
        self.envs.append(self.connect_db(db_path))
        # Each LMDB stores its sample count under the "length" key.
        length = pickle.loads(self.envs[-1].begin().get(
            "length".encode("ascii")))
        self._keys.append(list(range(length)))

    # Shared DBs: stride keys by world_size so ranks read disjoint
    # samples; truncate so every rank gets an equal share.
    for db_path in shared_db_paths:
        self.envs.append(self.connect_db(db_path))
        length = pickle.loads(self.envs[-1].begin().get(
            "length".encode("ascii")))
        length -= length % world_size
        self._keys.append(list(range(rank, length, world_size)))

    self._keylens = [len(k) for k in self._keys]
    # Cumulative key counts — presumably used to map a flat sample index
    # to a (db, key) pair; verify against __getitem__.
    self._keylen_cumulative = np.cumsum(self._keylens).tolist()
    self.num_samples = sum(self._keylens)

    self.transform = transform
def __init__(self, config, transform=None):
    """Open a sharded set of trajectory LMDBs, padded to equal replica size.

    Like the unpadded variant, each rank exclusively owns a strided subset
    of the DB files and shares the remainder. Additionally, every replica
    is padded (by repeating its last sample) to a common ``replica_size``
    so all dataloaders yield the same number of samples; duplicated
    samples are expected to be pruned in post-processing.

    Args:
        config (dict): must contain "src", a directory of *.lmdb files.
        transform (callable, optional): stored for later use on samples.
    """
    super(TrajectoryLmdbDataset, self).__init__()
    self.config = config

    world_size = distutils.get_world_size()
    rank = distutils.get_rank()

    srcdir = Path(self.config["src"])
    db_paths = sorted(srcdir.glob("*.lmdb"))
    assert len(db_paths) > 0, f"No LMDBs found in {srcdir}"

    # Read all LMDBs to set the size of each dataloader replica.
    lengths = []
    for db_path in db_paths:
        env = self.connect_db(db_path)
        lengths.append(
            pickle.loads(env.begin().get("length".encode("ascii"))))
        env.close()
    # Replica size = total of the largest ceil(num_dbs / world_size)
    # DBs, i.e. an upper bound on any single rank's sample count.
    lengths.sort(reverse=True)
    replica_size = sum(lengths[:math.ceil(len(lengths) / world_size)])

    # Each process only reads a subset of the DB files. However, since the
    # number of DB files may not be divisible by world size, the final
    # (num_dbs % world_size) are shared by all processes.
    num_full_dbs = len(db_paths) - (len(db_paths) % world_size)
    full_db_paths = db_paths[rank:num_full_dbs:world_size]
    shared_db_paths = db_paths[num_full_dbs:]
    self.db_paths = full_db_paths + shared_db_paths

    self._keys, self.envs = [], []
    # DBs owned exclusively by this rank: read every key.
    for db_path in full_db_paths:
        self.envs.append(self.connect_db(db_path))
        length = pickle.loads(self.envs[-1].begin().get(
            "length".encode("ascii")))
        self._keys.append(list(range(length)))

    # Shared DBs: stride keys by world_size so ranks read disjoint
    # samples; truncate so every rank gets an equal share.
    for db_path in shared_db_paths:
        self.envs.append(self.connect_db(db_path))
        length = pickle.loads(self.envs[-1].begin().get(
            "length".encode("ascii")))
        length -= length % world_size
        self._keys.append(list(range(rank, length, world_size)))

    keylens = [len(k) for k in self._keys]
    # Need to pad dataloaders so all have the same no. of samples.
    # This means that dataloaders will have some repeated samples
    # that need to be pruned out in post-processing.
    if sum(keylens) < replica_size:
        # Repeat the last key of the last DB to fill the shortfall.
        self._keys[-1].extend(
            [self._keys[-1][-1]] * (replica_size - sum(keylens)))
        keylens = [len(k) for k in self._keys]
    self._keylen_cumulative = np.cumsum(keylens).tolist()
    self.transform = transform
    self.num_samples = sum(keylens)
    # Every replica must end up exactly replica_size samples long.
    assert self.num_samples == replica_size
def forward(self, input: torch.Tensor, target: torch.Tensor):
    """Compute the wrapped loss, rescaling "mean" reductions globally.

    For "mean" reduction the loss is divided by the *global* sample count
    (all-reduced across DDP replicas) and multiplied by the world size,
    because DDP averages gradients across replicas. Any other reduction
    mode returns the wrapped loss untouched.
    """
    raw_loss = self.loss_fn(input, target)
    if self.reduction != "mean":
        return raw_loss
    # Total number of samples across every DDP replica.
    global_samples = distutils.all_reduce(
        input.shape[0], device=input.device
    )
    # Multiply by world size since gradients are averaged
    # across DDP replicas
    return raw_loss * distutils.get_world_size() / global_samples
def save_results(self, predictions, results_file, keys):
    """Save per-rank predictions to .npz, then merge them on the master.

    Each rank writes ``{name}_{results_file}_{rank}.npz`` containing the
    prediction ids plus the requested ``keys``. After a barrier, the
    master rank concatenates all rank files, de-duplicates by id (the
    distributed sampler may repeat samples to even out replica sizes),
    writes a single ``{name}_{results_file}.npz``, and removes the
    per-rank files.

    Args:
        predictions (dict): must contain "id" plus every key in ``keys``.
        results_file (str or None): file tag; no-op when None.
        keys (iterable[str]): prediction fields to persist.
    """
    if results_file is None:
        return

    results_file_path = os.path.join(
        self.config["cmd"]["results_dir"],
        f"{self.name}_{results_file}_{distutils.get_rank()}.npz",
    )
    np.savez_compressed(
        results_file_path,
        ids=predictions["id"],
        **{key: predictions[key] for key in keys},
    )

    # Wait for every rank to finish writing before the master gathers.
    distutils.synchronize()
    if distutils.is_master():
        gather_results = defaultdict(list)
        full_path = os.path.join(
            self.config["cmd"]["results_dir"],
            f"{self.name}_{results_file}.npz",
        )

        for i in range(distutils.get_world_size()):
            rank_path = os.path.join(
                self.config["cmd"]["results_dir"],
                f"{self.name}_{results_file}_{i}.npz",
            )
            rank_results = np.load(rank_path, allow_pickle=True)
            gather_results["ids"].extend(rank_results["ids"])
            for key in keys:
                gather_results[key].extend(rank_results[key])
            os.remove(rank_path)

        # Because of how distributed sampler works, some system ids
        # might be repeated to make no. of samples even across GPUs.
        _, idx = np.unique(gather_results["ids"], return_index=True)
        gather_results["ids"] = np.array(gather_results["ids"])[idx]
        for k in keys:
            if k == "forces":
                # Per-system force arrays are flattened into one array.
                gather_results[k] = np.concatenate(
                    np.array(gather_results[k])[idx]
                )
            elif k == "chunk_idx":
                # Convert per-system lengths into split offsets for the
                # concatenated forces; drop the trailing total.
                gather_results[k] = np.cumsum(
                    np.array(gather_results[k])[idx]
                )[:-1]
            else:
                gather_results[k] = np.array(gather_results[k])[idx]

        logging.info(f"Writing results to {full_path}")
        np.savez_compressed(full_path, **gather_results)
def get_sampler(self, dataset, batch_size, shuffle):
    """Build a BalancedBatchSampler over ``dataset`` for this DDP rank.

    If the optimizer config provides "load_balancing", that mode is used
    and balancing is forced; otherwise the mode defaults to "atoms"
    without forcing.
    """
    optim_cfg = self.config["optim"]
    has_explicit_mode = "load_balancing" in optim_cfg
    if has_explicit_mode:
        mode = optim_cfg["load_balancing"]
    else:
        mode = "atoms"

    return BalancedBatchSampler(
        dataset,
        batch_size=batch_size,
        num_replicas=distutils.get_world_size(),
        rank=distutils.get_rank(),
        device=self.device,
        mode=mode,
        shuffle=shuffle,
        force_balancing=has_explicit_mode,
    )
def load_task(self):
    """Load datasets and build train/val/test (and relaxation) loaders.

    For "trajectory_lmdb" datasets, split loaders are constructed
    directly from the registry dataset classes; an optional relaxation
    dataset gets a DistributedSampler-backed loader. Any other dataset
    type must provide its own ``get_dataloaders``. Also sets up the
    energy/force normalizers and, when visualization is enabled,
    label-distribution plots on the master rank.
    """
    print("### Loading dataset: {}".format(self.config["task"]["dataset"]))

    self.parallel_collater = ParallelCollater(
        # First arg is 0 on CPU, 1 otherwise — presumably the number of
        # GPU splits; confirm against ParallelCollater.
        1 if not self.cpu else 0,
        self.config["model_attributes"].get("otf_graph", False),
    )
    if self.config["task"]["dataset"] == "trajectory_lmdb":
        self.train_dataset = registry.get_dataset_class(
            self.config["task"]["dataset"])(self.config["dataset"])

        # NOTE(review): no DistributedSampler here — the trajectory
        # dataset appears to shard itself per rank; confirm.
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config["optim"]["batch_size"],
            shuffle=True,
            collate_fn=self.parallel_collater,
            num_workers=self.config["optim"]["num_workers"],
            pin_memory=True,
        )

        self.val_loader = self.test_loader = None

        if "val_dataset" in self.config:
            self.val_dataset = registry.get_dataset_class(
                self.config["task"]["dataset"])(self.config["val_dataset"])
            self.val_loader = DataLoader(
                self.val_dataset,
                self.config["optim"].get("eval_batch_size", 64),
                shuffle=False,
                collate_fn=self.parallel_collater,
                num_workers=self.config["optim"]["num_workers"],
                pin_memory=True,
            )
        if "test_dataset" in self.config:
            self.test_dataset = registry.get_dataset_class(
                self.config["task"]["dataset"])(
                    self.config["test_dataset"])
            self.test_loader = DataLoader(
                self.test_dataset,
                self.config["optim"].get("eval_batch_size", 64),
                shuffle=False,
                collate_fn=self.parallel_collater,
                num_workers=self.config["optim"]["num_workers"],
                pin_memory=True,
            )

        if "relax_dataset" in self.config["task"]:
            assert os.path.isfile(
                self.config["task"]["relax_dataset"]["src"])

            self.relax_dataset = registry.get_dataset_class(
                "single_point_lmdb")(self.config["task"]["relax_dataset"])

            # Relaxations are evaluated deterministically (no shuffle).
            self.relax_sampler = DistributedSampler(
                self.relax_dataset,
                num_replicas=distutils.get_world_size(),
                rank=distutils.get_rank(),
                shuffle=False,
            )
            self.relax_loader = DataLoader(
                self.relax_dataset,
                batch_size=self.config["optim"].get("eval_batch_size", 64),
                collate_fn=self.parallel_collater,
                num_workers=self.config["optim"]["num_workers"],
                pin_memory=True,
                sampler=self.relax_sampler,
            )
    else:
        # Non-LMDB datasets build all three loaders themselves.
        self.dataset = registry.get_dataset_class(
            self.config["task"]["dataset"])(self.config["dataset"])
        (
            self.train_loader,
            self.val_loader,
            self.test_loader,
        ) = self.dataset.get_dataloaders(
            batch_size=self.config["optim"]["batch_size"],
            collate_fn=self.parallel_collater,
        )

    self.num_targets = 1

    # Normalizer for the dataset.
    # Compute mean, std of training set labels.
    self.normalizers = {}
    if self.config["dataset"].get("normalize_labels", False):
        if "target_mean" in self.config["dataset"]:
            self.normalizers["target"] = Normalizer(
                mean=self.config["dataset"]["target_mean"],
                std=self.config["dataset"]["target_std"],
                device=self.device,
            )
        else:
            # Fall back to computing stats from the training labels.
            self.normalizers["target"] = Normalizer(
                tensor=self.train_loader.dataset.data.y[
                    self.train_loader.dataset.__indices__],
                device=self.device,
            )

    # If we're computing gradients wrt input, set mean of normalizer to 0 --
    # since it is lost when compute dy / dx -- and std to forward target std
    if self.config["model_attributes"].get("regress_forces", True):
        if self.config["dataset"].get("normalize_labels", False):
            if "grad_target_mean" in self.config["dataset"]:
                self.normalizers["grad_target"] = Normalizer(
                    mean=self.config["dataset"]["grad_target_mean"],
                    std=self.config["dataset"]["grad_target_std"],
                    device=self.device,
                )
            else:
                self.normalizers["grad_target"] = Normalizer(
                    tensor=self.train_loader.dataset.data.y[
                        self.train_loader.dataset.__indices__],
                    device=self.device,
                )
                self.normalizers["grad_target"].mean.fill_(0)

    if (self.is_vis and self.config["task"]["dataset"] != "qm9"
            and distutils.is_master()):
        # Plot label distribution.
        plots = [
            plot_histogram(
                self.train_loader.dataset.data.y.tolist(),
                xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                ylabel="# Examples",
                title="Split: train",
            ),
            plot_histogram(
                self.val_loader.dataset.data.y.tolist(),
                xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                ylabel="# Examples",
                title="Split: val",
            ),
            plot_histogram(
                self.test_loader.dataset.data.y.tolist(),
                xlabel="{}/raw".format(self.config["task"]["labels"][0]),
                ylabel="# Examples",
                title="Split: test",
            ),
        ]
        self.logger.log_plots(plots)
def run_relaxations(self, split="val", epoch=None):
    """Run ML-driven structure relaxations over the relaxation loader.

    The ``split`` argument is overridden internally: if the dataset
    carries relaxed labels (``pos_relaxed``/``y_relaxed``) this is a
    "val" run and IS2RS metrics are computed and all-reduced; otherwise
    it is treated as "test". When "write_pos" is set, relaxed positions
    are written per rank and merged (de-duplicated by system id) on the
    master rank.

    Args:
        split: nominal split name; replaced based on available labels.
        epoch: if given together with a logger, metrics are logged at a
            step derived from the epoch.
    """
    print("### Running ML-relaxations")
    self.model.eval()

    evaluator, metrics = Evaluator(task="is2rs"), {}

    # Labelled relaxation targets => evaluate as validation.
    if hasattr(self.relax_dataset[0], "pos_relaxed") and hasattr(
            self.relax_dataset[0], "y_relaxed"):
        split = "val"
    else:
        split = "test"

    ids = []
    relaxed_positions = []
    for i, batch in tqdm(enumerate(self.relax_loader),
                         total=len(self.relax_loader)):
        relaxed_batch = ml_relax(
            batch=batch,
            model=self,
            steps=self.config["task"].get("relaxation_steps", 200),
            fmax=self.config["task"].get("relaxation_fmax", 0.0),
            relax_opt=self.config["task"]["relax_opt"],
            device=self.device,
            transform=None,
        )

        if self.config["task"].get("write_pos", False):
            systemids = [str(i) for i in relaxed_batch.sid.tolist()]
            natoms = relaxed_batch.natoms.tolist()
            # Split the flat per-batch positions back into per-system
            # chunks before serializing.
            positions = torch.split(relaxed_batch.pos, natoms)
            batch_relaxed_positions = [pos.tolist() for pos in positions]

            relaxed_positions += batch_relaxed_positions
            ids += systemids

        if split == "val":
            # Only evaluate free atoms (fixed == 0).
            mask = relaxed_batch.fixed == 0
            s_idx = 0
            natoms_free = []
            for natoms in relaxed_batch.natoms:
                natoms_free.append(
                    torch.sum(mask[s_idx:s_idx + natoms]).item())
                s_idx += natoms

            target = {
                "energy": relaxed_batch.y_relaxed,
                "positions": relaxed_batch.pos_relaxed[mask],
                "cell": relaxed_batch.cell,
                "pbc": torch.tensor([True, True, True]),
                "natoms": torch.LongTensor(natoms_free),
            }

            prediction = {
                "energy": relaxed_batch.y,
                "positions": relaxed_batch.pos[mask],
                "cell": relaxed_batch.cell,
                "pbc": torch.tensor([True, True, True]),
                "natoms": torch.LongTensor(natoms_free),
            }

            metrics = evaluator.eval(prediction, target, metrics)

    if self.config["task"].get("write_pos", False):
        rank = distutils.get_rank()
        pos_filename = os.path.join(self.config["cmd"]["results_dir"],
                                    f"relaxed_pos_{rank}.npz")
        np.savez_compressed(
            pos_filename,
            ids=ids,
            pos=np.array(relaxed_positions, dtype=object),
        )

        # Wait until every rank has written before the master gathers.
        distutils.synchronize()
        if distutils.is_master():
            gather_results = defaultdict(list)
            full_path = os.path.join(
                self.config["cmd"]["results_dir"],
                "relaxed_positions.npz",
            )

            for i in range(distutils.get_world_size()):
                rank_path = os.path.join(
                    self.config["cmd"]["results_dir"],
                    f"relaxed_pos_{i}.npz",
                )
                rank_results = np.load(rank_path, allow_pickle=True)
                gather_results["ids"].extend(rank_results["ids"])
                gather_results["pos"].extend(rank_results["pos"])
                os.remove(rank_path)

            # Because of how distributed sampler works, some system ids
            # might be repeated to make no. of samples even across GPUs.
            _, idx = np.unique(gather_results["ids"], return_index=True)
            gather_results["ids"] = np.array(gather_results["ids"])[idx]
            gather_results["pos"] = np.array(gather_results["pos"],
                                             dtype=object)[idx]

            print(f"Writing results to {full_path}")
            np.savez_compressed(full_path, **gather_results)

    if split == "val":
        # Sum totals and counts across all ranks, then average.
        aggregated_metrics = {}
        for k in metrics:
            aggregated_metrics[k] = {
                "total": distutils.all_reduce(metrics[k]["total"],
                                              average=False,
                                              device=self.device),
                "numel": distutils.all_reduce(metrics[k]["numel"],
                                              average=False,
                                              device=self.device),
            }
            aggregated_metrics[k]["metric"] = (
                aggregated_metrics[k]["total"] /
                aggregated_metrics[k]["numel"])
        metrics = aggregated_metrics

        # Make plots.
        log_dict = {k: metrics[k]["metric"] for k in metrics}
        if self.logger is not None and epoch is not None:
            self.logger.log(
                log_dict,
                step=(epoch + 1) * len(self.train_loader),
                split=split,
            )

        if distutils.is_master():
            print(metrics)
def __init__(
    self,
    task,
    model,
    dataset,
    optimizer,
    identifier,
    run_dir=None,
    is_debug=False,
    is_vis=False,
    is_hpo=False,
    print_every=100,
    seed=None,
    logger="tensorboard",
    local_rank=0,
    amp=False,
    cpu=False,
    name="base_trainer",
):
    """Initialize the trainer: device, run directories, and config dict.

    Builds a rank-synchronized timestamp (broadcast from rank 0) used to
    name checkpoint/result/log directories, captures the git commit hash
    when available, assembles ``self.config``, then calls ``self.load()``
    and creates the task Evaluator.

    Args:
        task/model/dataset/optimizer: config sections; ``model`` must
            contain a "name" key, which is popped here.
        identifier: run label appended to the timestamp.
        run_dir: base output directory (defaults to the CWD).
        local_rank: CUDA device index used when a GPU is available.
        amp: enables a CUDA GradScaler for mixed precision.
        cpu: force CPU even if CUDA is available.
    """
    self.name = name
    self.cpu = cpu
    self.start_step = 0

    if torch.cuda.is_available() and not self.cpu:
        # NOTE(review): device is stored as a bare int (CUDA index)
        # here, not a torch.device.
        self.device = local_rank
    else:
        self.device = "cpu"
        self.cpu = True  # handle case when `--cpu` isn't specified
        # but there are no gpu devices available
    if run_dir is None:
        run_dir = os.getcwd()

    # Broadcast rank 0's timestamp so all ranks agree on directory names.
    timestamp = torch.tensor(datetime.datetime.now().timestamp()).to(
        self.device)
    # create directories from master rank only
    distutils.broadcast(timestamp, 0)
    timestamp = datetime.datetime.fromtimestamp(
        timestamp.int()).strftime("%Y-%m-%d-%H-%M-%S")
    if identifier:
        timestamp += "-{}".format(identifier)
    try:
        commit_hash = (subprocess.check_output([
            "git",
            "-C",
            ocpmodels.__path__[0],
            "describe",
            "--always",
        ]).strip().decode("ascii"))
    # catch instances where code is not being run from a git repo
    except Exception:
        commit_hash = None

    self.config = {
        "task": task,
        "model": model.pop("name"),
        "model_attributes": model,
        "optim": optimizer,
        "logger": logger,
        "amp": amp,
        "gpus": distutils.get_world_size() if not self.cpu else 0,
        "cmd": {
            "identifier": identifier,
            "print_every": print_every,
            "seed": seed,
            "timestamp": timestamp,
            "commit": commit_hash,
            "checkpoint_dir": os.path.join(run_dir, "checkpoints",
                                           timestamp),
            "results_dir": os.path.join(run_dir, "results", timestamp),
            "logs_dir": os.path.join(run_dir, "logs", logger, timestamp),
        },
    }
    # AMP Scaler
    self.scaler = torch.cuda.amp.GradScaler() if amp else None

    # A list is interpreted positionally as [train, val, test].
    if isinstance(dataset, list):
        self.config["dataset"] = dataset[0]
        if len(dataset) > 1:
            self.config["val_dataset"] = dataset[1]
        if len(dataset) > 2:
            self.config["test_dataset"] = dataset[2]
    else:
        self.config["dataset"] = dataset

    if not is_debug and distutils.is_master() and not is_hpo:
        os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True)
        os.makedirs(self.config["cmd"]["results_dir"], exist_ok=True)
        os.makedirs(self.config["cmd"]["logs_dir"], exist_ok=True)

    self.is_debug = is_debug
    self.is_vis = is_vis
    self.is_hpo = is_hpo

    if self.is_hpo:
        # sets the hpo checkpoint frequency
        # default is no checkpointing
        self.hpo_checkpoint_every = self.config["optim"].get(
            "checkpoint_every", -1)

    if distutils.is_master():
        print(yaml.dump(self.config, default_flow_style=False))

    self.load()

    self.evaluator = Evaluator(task=name)
def _compute_loss(self, out, batch_list):
    """Compute the combined energy + force training loss.

    Energy loss is always computed; force loss is added when the model
    regresses forces. Force loss uses either per-tag weighted L1 (with a
    globally all-reduced normalizer, as introduced in ForceNet) or the
    configured criterion, optionally restricted to free atoms.

    Args:
        out (dict): model outputs with "energy" and (optionally) "forces".
        batch_list: list of batches providing ``y``, ``force``, and
            (depending on config) ``tags``/``fixed``.

    Returns:
        Scalar loss tensor attached to the compute graph.
    """
    loss = []

    # Energy loss.
    energy_target = torch.cat(
        [batch.y.to(self.device) for batch in batch_list], dim=0)
    if self.config["dataset"].get("normalize_labels", False):
        energy_target = self.normalizers["target"].norm(energy_target)
    energy_mult = self.config["optim"].get("energy_coefficient", 1)
    loss.append(energy_mult * self.criterion(out["energy"], energy_target))

    # Force loss.
    if self.config["model_attributes"].get("regress_forces", True):
        force_target = torch.cat(
            [batch.force.to(self.device) for batch in batch_list], dim=0)
        if self.config["dataset"].get("normalize_labels", False):
            force_target = self.normalizers["grad_target"].norm(
                force_target)

        tag_specific_weights = self.config["task"].get(
            "tag_specific_weights", [])
        if tag_specific_weights != []:
            # handle tag specific weights as introduced in forcenet
            assert len(tag_specific_weights) == 3

            batch_tags = torch.cat(
                [
                    batch.tags.float().to(self.device)
                    for batch in batch_list
                ],
                dim=0,
            )
            # Map each atom's tag (0/1/2) to its configured weight.
            weight = torch.zeros_like(batch_tags)
            weight[batch_tags == 0] = tag_specific_weights[0]
            weight[batch_tags == 1] = tag_specific_weights[1]
            weight[batch_tags == 2] = tag_specific_weights[2]

            # Weighted L1 over all force components (3 per atom).
            loss_force_list = torch.abs(out["forces"] - force_target)
            train_loss_force_unnormalized = torch.sum(loss_force_list *
                                                      weight.view(-1, 1))
            train_loss_force_normalizer = 3.0 * weight.sum()

            # add up normalizer to obtain global normalizer
            distutils.all_reduce(train_loss_force_normalizer)

            # perform loss normalization before backprop
            train_loss_force_normalized = train_loss_force_unnormalized * (
                distutils.get_world_size() / train_loss_force_normalizer)
            loss.append(train_loss_force_normalized)
        else:
            # Force coefficient = 30 has been working well for us.
            force_mult = self.config["optim"].get("force_coefficient", 30)
            if self.config["task"].get("train_on_free_atoms", False):
                # Restrict the force loss to unconstrained atoms.
                fixed = torch.cat(
                    [batch.fixed.to(self.device) for batch in batch_list])
                mask = fixed == 0
                loss.append(force_mult * self.criterion(
                    out["forces"][mask], force_target[mask]))
            else:
                loss.append(force_mult *
                            self.criterion(out["forces"], force_target))

    # Sanity check to make sure the compute graph is correct.
    for lc in loss:
        assert hasattr(lc, "grad_fn")

    loss = sum(loss)
    return loss
def __init__(
    self,
    task,
    model,
    dataset,
    optimizer,
    identifier,
    normalizer=None,
    timestamp_id=None,
    run_dir=None,
    is_debug=False,
    is_vis=False,
    is_hpo=False,
    print_every=100,
    seed=None,
    logger="tensorboard",
    local_rank=0,
    amp=False,
    cpu=False,
    name="base_trainer",
    slurm={},
):
    """Initialize the trainer: device, run directories, config, SLURM info.

    Builds (or reuses) a timestamp id synchronized across ranks to name
    checkpoint/result/log directories, captures the git commit hash when
    available, assembles ``self.config`` (including SLURM metadata),
    resolves dataset/normalizer config, then calls ``self.load()`` and
    creates the task Evaluator.

    Args:
        task/model/dataset/optimizer: config sections; ``model`` must
            contain a "name" key, which is popped here.
        normalizer: explicit normalization params; when None and a train
            dataset config exists, legacy in-dataset params are used.
        timestamp_id: reuse an existing run id instead of generating one.
        dataset: list ([train, val, test]), dict with "train"/"val"/
            "test" keys, or a single train config.
        slurm: SLURM metadata dict; "%j" in its "folder" is expanded to
            the job id when running under SLURM.
        NOTE: ``slurm={}`` is a mutable default — safe only while this
        method never mutates the default instance in place.
    """
    self.name = name
    self.cpu = cpu
    self.epoch = 0
    self.step = 0

    if torch.cuda.is_available() and not self.cpu:
        self.device = torch.device(f"cuda:{local_rank}")
    else:
        self.device = torch.device("cpu")
        self.cpu = True  # handle case when `--cpu` isn't specified
        # but there are no gpu devices available

    if run_dir is None:
        run_dir = os.getcwd()

    if timestamp_id is None:
        # Broadcast rank 0's timestamp so all ranks agree on the run id.
        timestamp = torch.tensor(datetime.datetime.now().timestamp()).to(
            self.device
        )
        # create directories from master rank only
        distutils.broadcast(timestamp, 0)
        timestamp = datetime.datetime.fromtimestamp(
            timestamp.int()
        ).strftime("%Y-%m-%d-%H-%M-%S")
        if identifier:
            self.timestamp_id = f"{timestamp}-{identifier}"
        else:
            self.timestamp_id = timestamp
    else:
        self.timestamp_id = timestamp_id

    try:
        commit_hash = (
            subprocess.check_output(
                [
                    "git",
                    "-C",
                    ocpmodels.__path__[0],
                    "describe",
                    "--always",
                ]
            )
            .strip()
            .decode("ascii")
        )
    # catch instances where code is not being run from a git repo
    except Exception:
        commit_hash = None

    self.config = {
        "task": task,
        "model": model.pop("name"),
        "model_attributes": model,
        "optim": optimizer,
        "logger": logger,
        "amp": amp,
        "gpus": distutils.get_world_size() if not self.cpu else 0,
        "cmd": {
            "identifier": identifier,
            "print_every": print_every,
            "seed": seed,
            "timestamp_id": self.timestamp_id,
            "commit": commit_hash,
            "checkpoint_dir": os.path.join(
                run_dir, "checkpoints", self.timestamp_id
            ),
            "results_dir": os.path.join(
                run_dir, "results", self.timestamp_id
            ),
            "logs_dir": os.path.join(
                run_dir, "logs", logger, self.timestamp_id
            ),
        },
        "slurm": slurm,
    }
    # AMP Scaler
    self.scaler = torch.cuda.amp.GradScaler() if amp else None

    # Expand "%j" in the SLURM folder to the actual job id.
    if "SLURM_JOB_ID" in os.environ and "folder" in self.config["slurm"]:
        self.config["slurm"]["job_id"] = os.environ["SLURM_JOB_ID"]
        self.config["slurm"]["folder"] = self.config["slurm"][
            "folder"
        ].replace("%j", self.config["slurm"]["job_id"])

    # Datasets may be given as a positional list, a dict, or one config.
    if isinstance(dataset, list):
        if len(dataset) > 0:
            self.config["dataset"] = dataset[0]
            if len(dataset) > 1:
                self.config["val_dataset"] = dataset[1]
            if len(dataset) > 2:
                self.config["test_dataset"] = dataset[2]
    elif isinstance(dataset, dict):
        self.config["dataset"] = dataset.get("train", None)
        self.config["val_dataset"] = dataset.get("val", None)
        self.config["test_dataset"] = dataset.get("test", None)
    else:
        self.config["dataset"] = dataset

    self.normalizer = normalizer
    # This supports the legacy way of providing norm parameters in dataset
    if self.config.get("dataset", None) is not None and normalizer is None:
        self.normalizer = self.config["dataset"]

    if not is_debug and distutils.is_master() and not is_hpo:
        os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True)
        os.makedirs(self.config["cmd"]["results_dir"], exist_ok=True)
        os.makedirs(self.config["cmd"]["logs_dir"], exist_ok=True)

    self.is_debug = is_debug
    self.is_vis = is_vis
    self.is_hpo = is_hpo

    if self.is_hpo:
        # conditional import is necessary for checkpointing
        from ray import tune

        from ocpmodels.common.hpo_utils import tune_reporter

        # sets the hpo checkpoint frequency
        # default is no checkpointing
        self.hpo_checkpoint_every = self.config["optim"].get(
            "checkpoint_every", -1
        )

    if distutils.is_master():
        print(yaml.dump(self.config, default_flow_style=False))

    self.load()

    self.evaluator = Evaluator(task=name)