def validate(self, split="val", epoch=None, disable_tqdm=False): if distutils.is_master(): print("### Evaluating on {}.".format(split)) if self.is_hpo: disable_tqdm = True self.model.eval() evaluator, metrics = Evaluator(task=self.name), {} rank = distutils.get_rank() loader = self.val_loader if split == "val" else self.test_loader for i, batch in tqdm( enumerate(loader), total=len(loader), position=rank, desc="device {}".format(rank), disable=disable_tqdm, ): # Forward. with torch.cuda.amp.autocast(enabled=self.scaler is not None): out = self._forward(batch) loss = self._compute_loss(out, batch) # Compute metrics. metrics = self._compute_metrics(out, batch, evaluator, metrics) metrics = evaluator.update("loss", loss.item(), metrics) aggregated_metrics = {} for k in metrics: aggregated_metrics[k] = { "total": distutils.all_reduce(metrics[k]["total"], average=False, device=self.device), "numel": distutils.all_reduce(metrics[k]["numel"], average=False, device=self.device), } aggregated_metrics[k]["metric"] = (aggregated_metrics[k]["total"] / aggregated_metrics[k]["numel"]) metrics = aggregated_metrics log_dict = {k: metrics[k]["metric"] for k in metrics} log_dict.update({"epoch": epoch + 1}) if distutils.is_master(): log_str = ["{}: {:.4f}".format(k, v) for k, v in log_dict.items()] print(", ".join(log_str)) # Make plots. if self.logger is not None and epoch is not None: self.logger.log( log_dict, step=(epoch + 1) * len(self.train_loader), split=split, ) return metrics
def all_reduce(self, device):
    print("Total", self.total)
    self.total = distutils.all_reduce(self.total, device=device)
    self.count = distutils.all_reduce(self.count, device=device)

    # all_gather returns one list per rank; flatten them into a single global
    # history instead of zipping (zip would only wrap each rank's list in a
    # 1-tuple).
    series_list = distutils.all_gather(self.series, device=device)
    self.series = [value for rank_series in series_list for value in rank_series]

    deque_list = distutils.all_gather(self.deque, device=device)
    self.deque = deque(
        [value for rank_deque in deque_list for value in rank_deque],
        maxlen=self.window_size,
    )
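# Hypothetical illustration of the merge above: distutils.all_gather is assumed
# to return one list per rank, so flattening (rather than zipping) recovers a
# single global history. This standalone snippet mimics two ranks.
from collections import deque

window_size = 4
gathered_series = [[0.5, 0.7], [0.6, 0.4]]  # per-rank value histories
merged_series = [v for rank_series in gathered_series for v in rank_series]
merged_deque = deque(merged_series, maxlen=window_size)

assert merged_series == [0.5, 0.7, 0.6, 0.4]
assert list(merged_deque) == merged_series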
def forward(self, input: torch.Tensor, target: torch.Tensor):
    loss = self.loss_fn(input, target)
    if self.reduction == "mean":
        num_samples = input.shape[0]
        num_samples = distutils.all_reduce(
            num_samples, device=input.device
        )
        # Multiply by world size since gradients are averaged
        # across DDP replicas.
        return loss * distutils.get_world_size() / num_samples
    else:
        return loss
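# Sketch of why the scaling above recovers a global mean, assuming the wrapped
# self.loss_fn returns a sum over samples (names and numbers below are
# illustrative, not part of the trainer). DDP averages gradients across
# replicas, i.e. divides by world_size, so each rank scales its summed loss by
# world_size / global_num_samples and ends up with global_sum / global_N overall.
import torch

world_size = 2
per_rank_losses = [torch.tensor([1.0, 3.0]), torch.tensor([2.0, 2.0, 4.0])]

global_num_samples = sum(l.numel() for l in per_rank_losses)  # all_reduce of counts
per_rank_scaled = [
    l.sum() * world_size / global_num_samples for l in per_rank_losses
]
# DDP's gradient averaging is equivalent to averaging these scaled losses.
ddp_effective_loss = sum(per_rank_scaled) / world_size

assert torch.isclose(ddp_effective_loss, torch.cat(per_rank_losses).mean())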
def validate_relaxation(self, split="val", epoch=None):
    print("### Evaluating ML-relaxation")
    self.model.eval()

    mae_energy, mae_structure = relax_eval(
        trainer=self,
        traj_dir=self.config["task"]["relaxation_dir"],
        metric=self.config["task"]["metric"],
        steps=self.config["task"].get("relaxation_steps", 300),
        fmax=self.config["task"].get("relaxation_fmax", 0.01),
        results_dir=self.config["cmd"]["results_dir"],
    )

    mae_energy = distutils.all_reduce(
        mae_energy, average=True, device=self.device
    )
    mae_structure = distutils.all_reduce(
        mae_structure, average=True, device=self.device
    )

    log_dict = {
        "relaxed_energy_mae": mae_energy,
        "relaxed_structure_mae": mae_structure,
    }
    # Guard against the default epoch=None, which would make epoch + 1 fail.
    if epoch is not None:
        log_dict["epoch"] = epoch + 1

    # Make plots.
    if self.logger is not None and epoch is not None:
        self.logger.log(
            log_dict,
            step=(epoch + 1) * len(self.train_loader),
            split=split,
        )

    print(log_dict)
    return mae_energy, mae_structure
def run_relaxations(self, split="val", epoch=None):
    print("### Running ML-relaxations")
    self.model.eval()

    evaluator, metrics = Evaluator(task="is2rs"), {}

    if hasattr(self.relax_dataset[0], "pos_relaxed") and hasattr(
        self.relax_dataset[0], "y_relaxed"
    ):
        split = "val"
    else:
        split = "test"

    ids = []
    relaxed_positions = []
    for i, batch in tqdm(
        enumerate(self.relax_loader), total=len(self.relax_loader)
    ):
        relaxed_batch = ml_relax(
            batch=batch,
            model=self,
            steps=self.config["task"].get("relaxation_steps", 200),
            fmax=self.config["task"].get("relaxation_fmax", 0.0),
            relax_opt=self.config["task"]["relax_opt"],
            device=self.device,
            transform=None,
        )

        if self.config["task"].get("write_pos", False):
            systemids = [str(i) for i in relaxed_batch.sid.tolist()]
            natoms = relaxed_batch.natoms.tolist()
            positions = torch.split(relaxed_batch.pos, natoms)
            batch_relaxed_positions = [pos.tolist() for pos in positions]

            relaxed_positions += batch_relaxed_positions
            ids += systemids

        if split == "val":
            mask = relaxed_batch.fixed == 0
            s_idx = 0
            natoms_free = []
            for natoms in relaxed_batch.natoms:
                natoms_free.append(
                    torch.sum(mask[s_idx : s_idx + natoms]).item()
                )
                s_idx += natoms

            target = {
                "energy": relaxed_batch.y_relaxed,
                "positions": relaxed_batch.pos_relaxed[mask],
                "cell": relaxed_batch.cell,
                "pbc": torch.tensor([True, True, True]),
                "natoms": torch.LongTensor(natoms_free),
            }
            prediction = {
                "energy": relaxed_batch.y,
                "positions": relaxed_batch.pos[mask],
                "cell": relaxed_batch.cell,
                "pbc": torch.tensor([True, True, True]),
                "natoms": torch.LongTensor(natoms_free),
            }

            metrics = evaluator.eval(prediction, target, metrics)

    if self.config["task"].get("write_pos", False):
        rank = distutils.get_rank()
        pos_filename = os.path.join(
            self.config["cmd"]["results_dir"], f"relaxed_pos_{rank}.npz"
        )
        np.savez_compressed(
            pos_filename,
            ids=ids,
            pos=np.array(relaxed_positions, dtype=object),
        )

        distutils.synchronize()
        if distutils.is_master():
            gather_results = defaultdict(list)
            full_path = os.path.join(
                self.config["cmd"]["results_dir"],
                "relaxed_positions.npz",
            )

            for i in range(distutils.get_world_size()):
                rank_path = os.path.join(
                    self.config["cmd"]["results_dir"],
                    f"relaxed_pos_{i}.npz",
                )
                rank_results = np.load(rank_path, allow_pickle=True)
                gather_results["ids"].extend(rank_results["ids"])
                gather_results["pos"].extend(rank_results["pos"])
                os.remove(rank_path)

            # Because of how distributed sampler works, some system ids
            # might be repeated to make no. of samples even across GPUs.
            _, idx = np.unique(gather_results["ids"], return_index=True)
            gather_results["ids"] = np.array(gather_results["ids"])[idx]
            gather_results["pos"] = np.array(
                gather_results["pos"], dtype=object
            )[idx]

            print(f"Writing results to {full_path}")
            np.savez_compressed(full_path, **gather_results)

    if split == "val":
        aggregated_metrics = {}
        for k in metrics:
            aggregated_metrics[k] = {
                "total": distutils.all_reduce(
                    metrics[k]["total"], average=False, device=self.device
                ),
                "numel": distutils.all_reduce(
                    metrics[k]["numel"], average=False, device=self.device
                ),
            }
            aggregated_metrics[k]["metric"] = (
                aggregated_metrics[k]["total"]
                / aggregated_metrics[k]["numel"]
            )
        metrics = aggregated_metrics

        # Make plots.
        log_dict = {k: metrics[k]["metric"] for k in metrics}
        if self.logger is not None and epoch is not None:
            self.logger.log(
                log_dict,
                step=(epoch + 1) * len(self.train_loader),
                split=split,
            )

        if distutils.is_master():
            print(metrics)
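# Standalone illustration (made-up data) of the dedup step above: the
# distributed sampler may repeat system ids so that every rank sees the same
# number of samples, and np.unique(..., return_index=True) keeps one entry per
# id when the per-rank result files are merged.
import numpy as np

ids = np.array(["102", "7", "102", "9"])  # "102" was padded onto another rank
pos = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0],
                [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]], dtype=object)

_, idx = np.unique(ids, return_index=True)
unique_ids, unique_pos = ids[idx], pos[idx]

assert unique_ids.tolist() == ["102", "7", "9"]
assert len(unique_pos) == 3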
def _compute_loss(self, out, batch_list):
    loss = []

    # Energy loss.
    energy_target = torch.cat(
        [batch.y.to(self.device) for batch in batch_list], dim=0
    )
    if self.config["dataset"].get("normalize_labels", False):
        energy_target = self.normalizers["target"].norm(energy_target)
    energy_mult = self.config["optim"].get("energy_coefficient", 1)
    loss.append(energy_mult * self.criterion(out["energy"], energy_target))

    # Force loss.
    if self.config["model_attributes"].get("regress_forces", True):
        force_target = torch.cat(
            [batch.force.to(self.device) for batch in batch_list], dim=0
        )
        if self.config["dataset"].get("normalize_labels", False):
            force_target = self.normalizers["grad_target"].norm(force_target)

        tag_specific_weights = self.config["task"].get(
            "tag_specific_weights", []
        )
        if tag_specific_weights != []:
            # Handle tag-specific weights as introduced in ForceNet.
            assert len(tag_specific_weights) == 3

            batch_tags = torch.cat(
                [batch.tags.float().to(self.device) for batch in batch_list],
                dim=0,
            )
            weight = torch.zeros_like(batch_tags)
            weight[batch_tags == 0] = tag_specific_weights[0]
            weight[batch_tags == 1] = tag_specific_weights[1]
            weight[batch_tags == 2] = tag_specific_weights[2]

            loss_force_list = torch.abs(out["forces"] - force_target)
            train_loss_force_unnormalized = torch.sum(
                loss_force_list * weight.view(-1, 1)
            )
            train_loss_force_normalizer = 3.0 * weight.sum()

            # Add up normalizer to obtain global normalizer.
            distutils.all_reduce(train_loss_force_normalizer)

            # Perform loss normalization before backprop.
            train_loss_force_normalized = train_loss_force_unnormalized * (
                distutils.get_world_size() / train_loss_force_normalizer
            )
            loss.append(train_loss_force_normalized)
        else:
            # Force coefficient = 30 has been working well for us.
            force_mult = self.config["optim"].get("force_coefficient", 30)
            if self.config["task"].get("train_on_free_atoms", False):
                fixed = torch.cat(
                    [batch.fixed.to(self.device) for batch in batch_list]
                )
                mask = fixed == 0
                loss.append(
                    force_mult
                    * self.criterion(out["forces"][mask], force_target[mask])
                )
            else:
                loss.append(
                    force_mult * self.criterion(out["forces"], force_target)
                )

    # Sanity check to make sure the compute graph is correct.
    for lc in loss:
        assert hasattr(lc, "grad_fn")

    loss = sum(loss)
    return loss
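# Hypothetical single-process sketch of the tag-weighted force loss above: each
# atom's |F_pred - F_true| row is weighted by its tag, and the normalizer
# 3.0 * weight.sum() accounts for the three force components per atom, so
# uniform weights of 1.0 reduce the expression to a plain L1 mean.
import torch

tag_specific_weights = [1.0, 1.0, 1.0]  # one weight per atom tag (uniform here)
tags = torch.tensor([0, 2, 1])
forces_pred = torch.tensor([[0.0, 1.0, 0.0], [1.0, 1.0, 1.0], [2.0, 0.0, 0.0]])
forces_true = torch.zeros_like(forces_pred)

weight = torch.zeros_like(tags, dtype=torch.float)
for tag, w in enumerate(tag_specific_weights):
    weight[tags == tag] = w

unnormalized = torch.sum(torch.abs(forces_pred - forces_true) * weight.view(-1, 1))
normalizer = 3.0 * weight.sum()  # three components per weighted atom
loss = unnormalized / normalizer  # single process: world_size == 1, no all_reduce

assert torch.isclose(loss, torch.abs(forces_pred - forces_true).mean())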