def evaluate(model,
             loader,
             loss_fn,
             device,
             return_results=True,
             loss_is_normalized=True,
             submodel=None,
             **kwargs):
    """Evaluate the current state of the model using a given dataloader
    """

    model.eval()
    model.to(device)

    eval_loss = 0.0
    n_eval = 0

    all_results = []
    all_batches = []

    for batch in loader:
        # append batch_size
        batch = batch_to(batch, device)
        vsize = batch['nxyz'].size(0)
        n_eval += vsize

        # e.g. if the result is a sum of results from two models, and you just
        # want the prediction of one of those models
        if submodel is not None:
            results = getattr(model, submodel)(batch)
        else:
            results = model(batch, **kwargs)

        eval_batch_loss = loss_fn(batch, results).data.cpu().numpy()

        if loss_is_normalized:
            eval_loss += eval_batch_loss * vsize
        else:
            eval_loss += eval_batch_loss

        all_results.append(batch_detach(results))
        all_batches.append(batch_detach(batch))

        # del results
        # del batch

    # weighted average over batches
    if loss_is_normalized:
        eval_loss /= n_eval

    if not return_results:
        return {}, {}, eval_loss
    else:
        # this step can be slow
        all_results = concatenate_dict(*all_results)
        all_batches = concatenate_dict(*all_batches)

        return all_results, all_batches, eval_loss
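# Usage sketch for `evaluate` above (illustrative, not from the source).
# It assumes that `model`, `loader`, and `device` are already defined and
# that batches and model outputs are dicts sharing an 'energy' key; the
# `mse_loss` helper is hypothetical and only shows the (batch, results)
# signature the loss function needs.
import torch


def mse_loss(batch, results):
    """Mean squared error between predicted and target energies."""
    targ = batch['energy']
    pred = results['energy'].view(targ.shape)
    return torch.mean((pred - targ) ** 2)


results, targets, test_loss = evaluate(model,
                                       loader,
                                       loss_fn=mse_loss,
                                       device=device)
print(f"test loss: {float(test_loss):.4f}")
print(sorted(results.keys()))  # concatenated prediction keys, e.g. 'energy'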
def validate(self, device):
    """Validate the current state of the model using the validation set
    """
    self._model.eval()

    for h in self.hooks:
        h.on_validation_begin(self)

    val_loss = 0.0
    n_val = 0

    for val_batch in self.validation_loader:
        # move input to gpu, if needed
        val_batch = batch_to(val_batch, device)

        # append batch_size
        vsize = val_batch['nxyz'].size(0)
        n_val += vsize

        for h in self.hooks:
            h.on_validation_batch_begin(self)

        results = self._model(val_batch)
        val_batch_loss = self.loss_fn(val_batch, results).data.cpu().numpy()

        if self.loss_is_normalized:
            val_loss += val_batch_loss * vsize
        else:
            val_loss += val_batch_loss

        for h in self.hooks:
            h.on_validation_batch_end(self, val_batch, results)

    # weighted average over batches
    if self.loss_is_normalized:
        val_loss /= n_val

    if self.best_loss > val_loss:
        self.best_loss = val_loss
        torch.save(self._model, self.best_model)

    for h in self.hooks:
        h.on_validation_end(self, val_loss)
def calculate(
    self,
    atoms=None,
    properties=['energy', 'forces'],
    system_changes=all_changes,
):
    """Calculates the desired properties for the given AtomsBatch.

    Args:
        atoms (AtomsBatch): custom Atoms subclass that contains implementation
            of neighbor lists, batching and so on. Avoids the use of the
            Dataset to calculate using the models created.
        properties (list of str): 'energy', 'forces' or both
        system_changes (default from ase)
    """

    Calculator.calculate(self, atoms, properties, system_changes)

    # run model
    # atomsbatch = AtomsBatch(atoms)
    # batch_to(atomsbatch.get_batch(), self.device)
    batch = batch_to(atoms.get_batch(), self.device)

    # add keys so that the readout function can calculate these properties
    batch['energy'] = []
    if 'forces' in properties:
        batch['energy_grad'] = []

    prediction = self.model(batch)

    # change energy and force to numpy array
    energy = prediction['energy'].detach().cpu().numpy() * (
        1 / const.EV_TO_KCAL_MOL)
    energy_grad = prediction['energy_grad'].detach().cpu().numpy() * (
        1 / const.EV_TO_KCAL_MOL)

    self.results = {
        'energy': energy.reshape(-1),
    }

    if 'forces' in properties:
        self.results['forces'] = -energy_grad.reshape(-1, 3)
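# Usage sketch for the calculator above (illustrative): the `NeuralFF` name,
# the `AtomsBatch` constructor, and the import path are assumptions about how
# the surrounding package is organized; only the unit conversion
# (kcal/mol -> eV) and the force sign convention are taken from the code above.
from ase.build import molecule

from nff.io.ase import AtomsBatch, NeuralFF  # assumed import path / names

atoms = AtomsBatch(molecule('H2O'))          # AtomsBatch wraps an ase.Atoms
atoms.calc = NeuralFF(model=model, device='cuda:0')

energy = atoms.get_potential_energy()  # eV (converted from kcal/mol above)
forces = atoms.get_forces()            # eV/Angstrom, forces = -energy_grad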
def evaluate(model, loader, device, track, **kwargs):
    """
    Evaluate a model on a dataset.

    Args:
        model (nff.nn.models): original NFF model loaded
        loader (torch.utils.data.DataLoader): data loader
        device (Union[str, int]): device on which you run the model
        track (bool): whether to track progress while looping over
            the batches (e.g. with a progress bar)
    Returns:
        all_results (dict): dictionary of results
        all_batches (dict): dictionary of ground truth
    """

    model.eval()
    model.to(device)

    all_results = []
    all_batches = []

    iter_func = get_iter_func(track)

    for batch in iter_func(loader):
        batch = batch_to(batch, device)
        results = fps_and_pred(model, batch, **kwargs)

        all_results.append(batch_detach(results))

        # don't overload memory with unnecessary keys
        reduced_batch = {key: val for key, val in batch.items()
                         if key not in ['bond_idx', 'ji_idx', 'kj_idx',
                                        'nbr_list', 'bonded_nbr_list']}

        all_batches.append(batch_detach(reduced_batch))

    all_results = concatenate_dict(*all_results)
    all_batches = concatenate_dict(*all_batches)

    return all_results, all_batches
def train(self, device, n_epochs=MAX_EPOCHS):
    """Train the model for the given number of epochs on a specified device.

    Args:
        device (torch.device): device on which training takes place.
        n_epochs (int): number of training epochs.

    Note: Depending on the `hooks`, training can stop earlier than `n_epochs`.
    """
    self.to(device)
    self._stop = False

    # initialize loss, num_batches, and optimizer grad to 0
    loss = torch.tensor(0.0).to(device)
    num_batches = 0
    self.optimizer.zero_grad()

    for h in self.hooks:
        h.on_train_begin(self)
        if hasattr(h, "mini_batches"):
            h.mini_batches = self.mini_batches

    try:
        for _ in range(n_epochs):

            self._model.train()
            self.epoch += 1

            for h in self.hooks:
                h.on_epoch_begin(self)

            if self._stop:
                break

            for j, batch in enumerate(self.train_loader):
                batch = batch_to(batch, device)

                for h in self.hooks:
                    h.on_batch_begin(self, batch)

                results = self._model(batch)
                loss += self.loss_fn(batch, results)
                self.step += 1

                # update the loss `self.mini_batches` number
                # of times before taking a step
                num_batches += 1

                if num_batches == self.mini_batches:

                    loss.backward()
                    self.optimizer.step()

                    for h in self.hooks:
                        h.on_batch_end(self, batch, results, loss)

                    # reset loss, num_batches, and the optimizer grad
                    loss = torch.tensor(0.0).to(device)
                    num_batches = 0
                    self.optimizer.zero_grad()

                if self._stop:
                    break

            if self.epoch % self.checkpoint_interval == 0:
                self.store_checkpoint()

            # validation
            if self.epoch % self.validation_interval == 0 or self._stop:
                self.validate(device)

            for h in self.hooks:
                h.on_epoch_end(self)

            if self._stop:
                break

        # Training Ends
        # run hooks & store checkpoint
        for h in self.hooks:
            h.on_train_ends(self)

        self.store_checkpoint()

    except Exception as e:
        for h in self.hooks:
            h.on_train_failed(self)
        raise e
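# Standalone illustration (plain PyTorch, synthetic data) of the gradient
# accumulation pattern used in `train` above: the loss is accumulated over
# `mini_batches` batches before a single optimizer step is taken. All names
# here are local to the example and are not part of the trainer.
import torch
from torch import nn, optim

model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()
loader = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(8)]
mini_batches = 4  # number of batches accumulated per optimizer step

optimizer.zero_grad()
for j, (x, y) in enumerate(loader):
    loss = loss_fn(model(x), y) / mini_batches  # scale so gradients average
    loss.backward()                             # accumulate into .grad
    if (j + 1) % mini_batches == 0:
        optimizer.step()
        optimizer.zero_grad()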
def validate(self, device, test=False):
    """Validate the current state of the model using the validation set
    """
    self._model.eval()

    for h in self.hooks:
        h.on_validation_begin(self)

    val_loss = 0.0
    n_val = 0

    for val_batch in self.validation_loader:
        val_batch = batch_to(val_batch, device)

        # append batch_size
        if self.mol_loss_norm:
            vsize = len(val_batch["num_atoms"])
        elif self.loss_is_normalized:
            vsize = val_batch['nxyz'].size(0)

        n_val += vsize

        for h in self.hooks:
            h.on_validation_batch_begin(self)

        results = self.call_model(val_batch, train=False)

        # detach from the graph
        results = batch_to(batch_detach(results), device)

        val_batch_loss = self.loss_fn(
            val_batch, results).data.cpu().numpy()

        if self.loss_is_normalized or self.mol_loss_norm:
            val_loss += val_batch_loss * vsize
        else:
            val_loss += val_batch_loss

        for h in self.hooks:
            h.on_validation_batch_end(self, val_batch, results)

        if test:
            return

    # weighted average over batches
    if self.loss_is_normalized or self.mol_loss_norm:
        val_loss /= n_val

    # if running in parallel, save the validation loss
    # and pick up the losses from the other processes too
    if self.parallel:
        self.save_val_loss(val_loss, n_val)
        val_loss = self.load_val_loss()

    for h in self.hooks:
        # delay this until after we know what the real
        # val loss is (e.g. if it's from a metric)
        if isinstance(h, ReduceLROnPlateauHook):
            continue

        h.on_validation_end(self, val_loss)

        metric_dic = getattr(h, "metric_dic", None)
        if metric_dic is None:
            continue

        if self.metric_as_loss in metric_dic:
            val_loss = metric_dic[self.metric_as_loss]
            if self.metric_objective.lower() == "maximize":
                val_loss *= -1

    for h in self.hooks:
        if not isinstance(h, ReduceLROnPlateauHook):
            continue
        h.on_validation_end(self, val_loss)

    if self.best_loss > val_loss:
        self.best_loss = val_loss
        self.save_as_best()
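# Hypothetical hook sketch showing how a `metric_dic` attribute interacts
# with `metric_as_loss` in `validate` above: a hook that exposes a validation
# energy MAE, which the trainer can then use in place of the loss for
# best-model selection. The class name, attribute layout, and the assumption
# that batches and results share an 'energy' key are all illustrative, not
# the package's real hook API.
class ValidationMAEHook:
    def __init__(self):
        self.metric_dic = {}
        self._abs_err = 0.0
        self._count = 0

    def on_validation_begin(self, trainer):
        self._abs_err, self._count = 0.0, 0

    def on_validation_batch_begin(self, trainer):
        pass

    def on_validation_batch_end(self, trainer, batch, results):
        err = (results['energy'].view(-1) - batch['energy'].view(-1)).abs()
        self._abs_err += err.sum().item()
        self._count += err.numel()

    def on_validation_end(self, trainer, val_loss):
        # the trainer reads `metric_dic` right after this call
        self.metric_dic = {'energy_mae': self._abs_err / max(self._count, 1)}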
def train(self, device, n_epochs=MAX_EPOCHS):
    """Train the model for the given number of epochs on a specified device.

    Args:
        device (torch.device): device on which training takes place.
        n_epochs (int): number of training epochs.

    Note: Depending on the `hooks`, training can stop earlier than `n_epochs`.
    """
    self.to(device)
    self._stop = False

    # initialize loss, num_batches, and optimizer grad to 0
    loss = torch.tensor(0.0).to(device)
    num_batches = 0
    self.optimizer.zero_grad()
    self.save_as_best()

    for h in self.hooks:
        h.on_train_begin(self)
        if hasattr(h, "mini_batches"):
            h.mini_batches = self.mini_batches

    try:
        for _ in range(n_epochs):

            self._model.train()
            self.epoch += 1

            for h in self.hooks:
                h.on_epoch_begin(self)

            if self._stop:
                break

            for j, batch in self.tqdm_enum(self.train_loader):

                batch = batch_to(batch, device)

                for h in self.hooks:
                    h.on_batch_begin(self, batch)

                results = self.call_model(batch, train=True)
                mini_loss = self.get_loss(batch, results)
                self.loss_backward(mini_loss)

                if not torch.isnan(mini_loss):
                    loss += mini_loss.cpu().detach().to(device)

                self.step += 1

                # update the loss `self.mini_batches` number
                # of times before taking a step
                num_batches += 1

                if num_batches == self.mini_batches:

                    loss /= self.nloss
                    num_batches = 0

                    # effective number of batches so far
                    eff_batches = int((j + 1) / self.mini_batches)
                    self.optim_step(batch_num=eff_batches, device=device)

                    for h in self.hooks:
                        h.on_batch_end(self, batch, results, loss)

                    # reset loss and the optimizer grad
                    loss = torch.tensor(0.0).to(device)
                    self.optimizer.zero_grad()

                if any((self.batch_stop, self._stop,
                        j == self.epoch_cutoff)):
                    break

            # reset for next epoch
            del mini_loss
            num_batches = 0
            loss = torch.tensor(0.0).to(device)
            self.optimizer.zero_grad()

            # store the checkpoint only if this is the base model,
            # otherwise it will get stored unnecessarily from other
            # gpus, which will cause IO issues
            if (self.epoch % self.checkpoint_interval == 0
                    and self.base):
                self.store_checkpoint()

            # validation
            if (self.epoch % self.validation_interval == 0
                    or self._stop):
                self.validate(device)

            for h in self.hooks:
                h.on_epoch_end(self)

        # Training Ends
        # run hooks & store checkpoint
        for h in self.hooks:
            h.on_train_ends(self)

        if self.base:
            self.store_checkpoint()

    except Exception as e:
        for h in self.hooks:
            h.on_train_failed(self)
        raise e
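# Standalone sketch of the "base process only" checkpointing idea referenced
# above (`self.base`): in multi-GPU training, only one rank writes to disk so
# the processes don't collide on the same file. The function name and path
# are illustrative, not part of the trainer.
import torch
import torch.distributed as dist


def store_checkpoint_rank0(model, path="checkpoint.pth.tar"):
    """Save a checkpoint from rank 0 only (or when not running distributed)."""
    if not dist.is_initialized() or dist.get_rank() == 0:
        torch.save(model.state_dict(), path)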