def validate(self, split="val", epoch=None):
    """Evaluate the model on the validation or test loader and report metrics.

    Every batch of the selected loader is passed through ``self._forward``
    and ``self._compute_loss``. The scalar loss and the metrics dict returned
    by ``_forward`` are accumulated in a ``Meter``. The meter is printed. It
    is also sent to ``self.logger``, but only when a logger is configured and
    ``epoch`` is given.

    Args:
        split: ``"val"`` evaluates ``self.val_loader``; any other value
            evaluates ``self.test_loader``. Also forwarded to the ``Meter``
            and the logger.
        epoch: Current epoch index. It is logged as ``epoch + 1`` and used
            to derive the logging step. When ``None``, nothing is logged.
    """
    print("### Evaluating on {}.".format(split))
    self.model.eval()
    meter = Meter(split=split)

    # Any split other than "val" falls through to the test loader.
    loader = self.val_loader if split == "val" else self.test_loader

    # NOTE(review): gradients are not disabled here. If ``_forward`` does not
    # rely on autograd, wrapping this loop in ``torch.no_grad()`` would save
    # memory. Confirm before changing.
    # The batch index was never used, so the loader is iterated directly.
    for batch in loader:
        # Forward. The batch is handed over exactly as the loader yields it;
        # device placement presumably happens inside _forward. Confirm.
        out, metrics = self._forward(batch)
        loss = self._compute_loss(out, batch)

        # Update meter with the scalar loss plus whatever _forward reported.
        meter_update_dict = {"loss": loss.item()}
        meter_update_dict.update(metrics)
        meter.update(meter_update_dict)

    # Make plots.
    if self.logger is not None and epoch is not None:
        log_dict = meter.get_scalar_dict()
        log_dict.update({"epoch": epoch + 1})
        self.logger.log(
            log_dict,
            # (epoch + 1) times the number of batches in the train loader.
            step=(epoch + 1) * len(self.train_loader),
            split=split,
        )

    print(meter)
def validate_relaxation(self, split="val", epoch=None):
    """Evaluate the model by running ML-driven relaxations through ``relax_eval``.

    Relaxation settings are read from ``self.config["task"]``:

    - ``relaxation_dir`` is required.
    - ``metric`` is required.
    - ``relaxation_steps`` is optional and defaults to 300.
    - ``relaxation_fmax`` is optional and defaults to 0.01.

    ``self.config["cmd"]["results_dir"]`` is forwarded as ``results_dir``.
    The two values returned by ``relax_eval`` are recorded in a ``Meter``
    under ``"relaxed_energy/<metric>"`` and ``"relaxed_structure/<metric>"``.
    The meter is printed. It is also sent to ``self.logger`` when a logger is
    configured and ``epoch`` is given.

    Args:
        split: Name passed to the ``Meter`` and the logger. It does not
            select any data in this method.
        epoch: Current epoch index. It is logged as ``epoch + 1`` and used
            to derive the logging step. When ``None``, nothing is logged.

    Returns:
        ``(mae_energy, mae_structure)`` exactly as returned by ``relax_eval``.
    """
    print("### Evaluating ML-relaxation")
    self.model.eval()
    meter = Meter(split=split)

    task_cfg = self.config["task"]
    mae_energy, mae_structure = relax_eval(
        trainer=self,
        traj_dir=task_cfg["relaxation_dir"],
        metric=task_cfg["metric"],
        steps=task_cfg.get("relaxation_steps", 300),
        fmax=task_cfg.get("relaxation_fmax", 0.01),
        results_dir=self.config["cmd"]["results_dir"],
    )

    # Re-read the metric name after the call, matching the original's lookups.
    metric_name = self.config["task"]["metric"]
    meter.update(
        {
            "relaxed_energy/{}".format(metric_name): mae_energy,
            "relaxed_structure/{}".format(metric_name): mae_structure,
        }
    )

    # Logging requires both a configured logger and a known epoch.
    if self.logger is not None and epoch is not None:
        scalars = meter.get_scalar_dict()
        scalars["epoch"] = epoch + 1
        global_step = (epoch + 1) * len(self.train_loader)
        self.logger.log(scalars, step=global_step, split=split)

    print(meter)
    return mae_energy, mae_structure
def validate(self, split="val", epoch=None):
    """Evaluate the model on a split and return its average loss and primary metric.

    Each batch of the selected loader goes through three steps:

    1. It is moved to ``self.device``.
    2. It is passed through ``self._forward`` and ``self._compute_loss``.
    3. The scalar loss and the metrics dict from ``_forward`` are
       accumulated in a ``Meter``.

    The meter is printed. It is also sent to ``self.logger`` when a logger is
    configured and ``epoch`` is given.

    Args:
        split: ``"val"`` evaluates ``self.val_loader``; any other value
            evaluates ``self.test_loader``. Also forwarded to the ``Meter``
            and the logger.
        epoch: Current epoch index. It is logged as ``epoch + 1`` and used
            to derive the logging step. When ``None``, nothing is logged.

    Returns:
        A ``(loss, metric)`` tuple of floats: the meter's ``global_avg`` for
        the loss, and for the entry keyed
        ``self.config["task"]["labels"][0] + "/" + self.config["task"]["metric"]``.
    """
    print("### Evaluating on {}.".format(split))
    self.model.eval()
    meter = Meter(split=split)

    # Any split other than "val" falls through to the test loader.
    loader = self.val_loader if split == "val" else self.test_loader

    # NOTE(review): gradients are not disabled here. If ``_forward`` does not
    # rely on autograd, wrapping this loop in ``torch.no_grad()`` would save
    # memory. Confirm before changing.
    # The batch index was never used, so the loader is iterated directly.
    for batch in loader:
        batch = batch.to(self.device)

        # Forward.
        out, metrics = self._forward(batch)
        loss = self._compute_loss(out, batch)

        # Update meter with the scalar loss plus whatever _forward reported.
        meter_update_dict = {"loss": loss.item()}
        meter_update_dict.update(metrics)
        meter.update(meter_update_dict)

    # Make plots.
    if self.logger is not None and epoch is not None:
        log_dict = meter.get_scalar_dict()
        log_dict.update({"epoch": epoch + 1})
        self.logger.log(
            log_dict,
            # (epoch + 1) times the number of batches in the train loader.
            step=(epoch + 1) * len(self.train_loader),
            split=split,
        )

    print(meter)

    # ``meter.loss`` presumably resolves to the entry fed by the "loss" key
    # above. Confirm against Meter.
    avg_loss = float(meter.loss.global_avg)
    # The primary metric must have been reported by _forward under this key.
    primary_key = (
        self.config["task"]["labels"][0] + "/" + self.config["task"]["metric"]
    )
    avg_metric = float(meter.meters[primary_key].global_avg)
    return (avg_loss, avg_metric)