Example #1
def evaluate(model,
             loader,
             loss_fn,
             device,
             return_results=True,
             loss_is_normalized=True,
             submodel=None,
             **kwargs):
    """Evaluate the current state of the model using a given dataloader
    """

    model.eval()
    model.to(device)

    eval_loss = 0.0
    n_eval = 0

    all_results = []
    all_batches = []

    for batch in loader:
        # move the batch to the target device
        batch = batch_to(batch, device)

        # count the atoms in this batch for loss weighting
        vsize = batch['nxyz'].size(0)
        n_eval += vsize

        # e.g. if the result is a sum of results from two models, and you just
        # want the prediction of one of those models
        if submodel is not None:
            results = getattr(model, submodel)(batch)
        else:
            results = model(batch, **kwargs)

        eval_batch_loss = loss_fn(batch, results).data.cpu().numpy()

        if loss_is_normalized:
            eval_loss += eval_batch_loss * vsize
        else:
            eval_loss += eval_batch_loss

        all_results.append(batch_detach(results))
        all_batches.append(batch_detach(batch))

        # del results
        # del batch

    # weighted average over batches
    if loss_is_normalized:
        eval_loss /= n_eval

    if not return_results:
        return {}, {}, eval_loss

    else:
        # concatenating the per-batch dictionaries can be slow
        all_results = concatenate_dict(*all_results)
        all_batches = concatenate_dict(*all_batches)

        return all_results, all_batches, eval_loss
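
When loss_is_normalized is set, the loop above undoes the per-atom averaging of each batch loss and re-averages over all atoms at the end. A minimal self-contained sketch of that weighting pattern, using made-up per-batch losses and atom counts rather than the NFF loss functions:

# toy per-batch losses, each already averaged per atom (not NFF output)
batch_losses = [0.50, 0.40, 0.30]
batch_sizes = [10, 20, 30]   # number of atoms in each batch

eval_loss = 0.0
n_eval = 0
for batch_loss, vsize in zip(batch_losses, batch_sizes):
    # undo the per-atom averaging before summing
    eval_loss += batch_loss * vsize
    n_eval += vsize

# weighted average over every atom seen during evaluation
eval_loss /= n_eval
print(round(eval_loss, 4))  # 0.3667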
Example #2
    def validate(self, device):
        """Validate the current state of the model using the validation set
        """

        self._model.eval()

        for h in self.hooks:
            h.on_validation_begin(self)

        val_loss = 0.0
        n_val = 0

        for val_batch in self.validation_loader:

            val_batch = batch_to(val_batch, device)

            # count the atoms in this batch for loss weighting
            vsize = val_batch['nxyz'].size(0)
            n_val += vsize

            for h in self.hooks:
                h.on_validation_batch_begin(self)

            # forward pass on the validation batch
            results = self._model(val_batch)

            val_batch_loss = self.loss_fn(val_batch,
                                          results).data.cpu().numpy()

            if self.loss_is_normalized:
                val_loss += val_batch_loss * vsize
            else:
                val_loss += val_batch_loss

            for h in self.hooks:
                h.on_validation_batch_end(self, val_batch, results)

        # weighted average over batches
        if self.loss_is_normalized:
            val_loss /= n_val

        if self.best_loss > val_loss:
            self.best_loss = val_loss
            torch.save(self._model, self.best_model)

        for h in self.hooks:
            h.on_validation_end(self, val_loss)
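
At the end of validation the trainer keeps the best model seen so far by comparing val_loss with self.best_loss and calling torch.save. A standalone sketch of that checkpointing idea, with a toy linear model and made-up validation losses (the output path is hypothetical):

import torch
import torch.nn as nn

model = nn.Linear(4, 1)              # toy stand-in for the trained model
best_loss = float('inf')
best_model_path = 'best_model.pt'    # hypothetical output path

# pretend these came from successive validation passes
for val_loss in [0.9, 0.7, 0.8, 0.5]:
    if best_loss > val_loss:
        best_loss = val_loss
        # save the whole model, as the trainer above does
        torch.save(model, best_model_path)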
Example #3
    def calculate(
        self,
        atoms=None,
        properties=['energy', 'forces'],
        system_changes=all_changes,
    ):
        """Calculates the desired properties for the given AtomsBatch.

        Args:
            atoms (AtomsBatch): custom Atoms subclass that contains implementation
                of neighbor lists, batching and so on. Avoids the use of the Dataset
                to calculate using the models created.
            properties (list of str): 'energy', 'forces' or both
            system_changes (default from ase)
        """

        Calculator.calculate(self, atoms, properties, system_changes)

        # convert the AtomsBatch into a batch dictionary and move it to the device
        batch = batch_to(atoms.get_batch(), self.device)

        # add keys so that the readout function can calculate these properties
        batch['energy'] = []
        if 'forces' in properties:
            batch['energy_grad'] = []

        prediction = self.model(batch)

        # convert energy and energy gradient to numpy arrays in eV (ASE units)
        energy = prediction['energy'].detach().cpu().numpy() * (
            1 / const.EV_TO_KCAL_MOL)
        energy_grad = prediction['energy_grad'].detach().cpu().numpy() * (
            1 / const.EV_TO_KCAL_MOL)

        self.results = {
            'energy': energy.reshape(-1),
        }

        if 'forces' in properties:
            self.results['forces'] = -energy_grad.reshape(-1, 3)
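
The calculator's post-processing simply rescales the prediction from kcal/mol to eV (ASE's energy unit) and reports forces as the negative energy gradient. A self-contained sketch of that step with dummy arrays in place of the model output (the conversion constant is assumed to be the usual 1 eV ≈ 23.06 kcal/mol):

import numpy as np

EV_TO_KCAL_MOL = 23.0605   # assumed value of the conversion constant

# dummy prediction in kcal/mol: one energy, gradients for two atoms
pred_energy = np.array([[-150.0]])
pred_energy_grad = np.array([[1.0, 0.0, -2.0],
                             [0.5, 0.5, 0.5]])

# convert to eV, flatten the energy, and negate the gradient to get forces
energy = (pred_energy / EV_TO_KCAL_MOL).reshape(-1)
forces = -(pred_energy_grad / EV_TO_KCAL_MOL).reshape(-1, 3)
print(energy.shape, forces.shape)   # (1,) (2, 3)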
Example #4
def evaluate(model, loader, device, track, **kwargs):
    """
    Evaluate a model on a dataset.
    Args:
      model (nff.nn.models): original NFF model loaded
      loader (torch.utils.data.DataLoader): data loader
      device (Union[str, int]): device on which you run the model
    Returns:
      all_results (dict): dictionary of results
      all_batches (dict): dictionary of ground truth
    """

    model.eval()
    model.to(device)

    all_results = []
    all_batches = []

    iter_func = get_iter_func(track)

    for batch in iter_func(loader):

        batch = batch_to(batch, device)
        results = fps_and_pred(model, batch, **kwargs)

        all_results.append(batch_detach(results))

        # don't overload memory with unnecessary keys
        reduced_batch = {
            key: val
            for key, val in batch.items() if key not in
            ['bond_idx', 'ji_idx', 'kj_idx', 'nbr_list', 'bonded_nbr_list']
        }
        all_batches.append(batch_detach(reduced_batch))

    all_results = concatenate_dict(*all_results)
    all_batches = concatenate_dict(*all_batches)

    return all_results, all_batches
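
batch_detach and the key filter above ensure that only detached tensors, without the bulky neighbor-list indices, are accumulated across batches. A simplified stand-in for that step on a toy batch (this is not the library's batch_detach, just the same idea):

import torch

def detach_dict(dic):
    # detach tensors so no autograd graph is kept alive between batches
    return {key: val.detach().cpu() if torch.is_tensor(val) else val
            for key, val in dic.items()}

skip_keys = ['bond_idx', 'ji_idx', 'kj_idx', 'nbr_list', 'bonded_nbr_list']

batch = {
    'nxyz': torch.rand(5, 4, requires_grad=True),
    'nbr_list': torch.randint(0, 5, (12, 2)),
    'energy': torch.rand(1, requires_grad=True),
}

reduced = detach_dict({key: val for key, val in batch.items()
                       if key not in skip_keys})
print(list(reduced.keys()))   # ['nxyz', 'energy']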
Example #5
    def train(self, device, n_epochs=MAX_EPOCHS):
        """Train the model for the given number of epochs on a specified device.

        Args:
            device (torch.device): device on which training takes place.
            n_epochs (int): number of training epochs.

        Note: Depending on the `hooks`, training can stop earlier than `n_epochs`.

        """
        self.to(device)

        self._stop = False
        # initialize loss, num_batches, and optimizer grad to 0
        loss = torch.tensor(0.0).to(device)
        num_batches = 0
        self.optimizer.zero_grad()

        for h in self.hooks:
            h.on_train_begin(self)
            if hasattr(h, "mini_batches"):
                h.mini_batches = self.mini_batches

        try:
            for _ in range(n_epochs):
                self._model.train()

                self.epoch += 1

                for h in self.hooks:
                    h.on_epoch_begin(self)

                if self._stop:
                    break

                for j, batch in enumerate(self.train_loader):

                    batch = batch_to(batch, device)

                    for h in self.hooks:
                        h.on_batch_begin(self, batch)

                    results = self._model(batch)
                    loss += self.loss_fn(batch, results)
                    self.step += 1

                    # accumulate the loss over self.mini_batches batches
                    # before taking an optimizer step
                    num_batches += 1
                    if num_batches == self.mini_batches:
                        loss.backward()
                        self.optimizer.step()

                        for h in self.hooks:
                            h.on_batch_end(self, batch, results, loss)

                        # reset loss, num_batches, and the optimizer grad
                        loss = torch.tensor(0.0).to(device)
                        num_batches = 0
                        self.optimizer.zero_grad()

                    if self._stop:
                        break

                if self.epoch % self.checkpoint_interval == 0:
                    self.store_checkpoint()

                # validation
                if self.epoch % self.validation_interval == 0 or self._stop:
                    self.validate(device)

                for h in self.hooks:
                    h.on_epoch_end(self)

                if self._stop:
                    break

            # Training Ends
            # run hooks & store checkpoint
            for h in self.hooks:
                h.on_train_ends(self)
            self.store_checkpoint()

        except Exception as e:
            for h in self.hooks:
                h.on_train_failed(self)

            raise e
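
The num_batches bookkeeping above is a gradient-accumulation loop: losses from self.mini_batches consecutive batches are summed, then backpropagated and stepped once. A minimal self-contained version of that pattern with a toy model and random data (plain PyTorch, not the NFF trainer):

import torch
import torch.nn as nn

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()
mini_batches = 4          # accumulate this many batches per optimizer step

loss = torch.tensor(0.0)
num_batches = 0
optimizer.zero_grad()

for _ in range(16):       # toy "loader" of random batches
    x, y = torch.rand(32, 8), torch.rand(32, 1)
    loss = loss + loss_fn(model(x), y)
    num_batches += 1

    if num_batches == mini_batches:
        loss.backward()
        optimizer.step()
        # reset the accumulated loss and the gradients
        loss = torch.tensor(0.0)
        num_batches = 0
        optimizer.zero_grad()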
Example #6
    def validate(self, device, test=False):
        """Validate the current state of the model using the validation set
        """

        self._model.eval()

        for h in self.hooks:
            h.on_validation_begin(self)

        val_loss = 0.0
        n_val = 0

        for val_batch in self.validation_loader:

            val_batch = batch_to(val_batch, device)

            # batch size: number of molecules or number of atoms, depending on
            # how the loss is normalized
            if self.mol_loss_norm:
                vsize = len(val_batch["num_atoms"])

            elif self.loss_is_normalized:
                vsize = val_batch['nxyz'].size(0)

            n_val += vsize

            for h in self.hooks:
                h.on_validation_batch_begin(self)

            results = self.call_model(val_batch, train=False)
            # detach from the graph
            results = batch_to(batch_detach(results), device)

            val_batch_loss = self.loss_fn(
                val_batch, results).data.cpu().numpy()

            if self.loss_is_normalized or self.mol_loss_norm:
                val_loss += val_batch_loss * vsize

            else:
                val_loss += val_batch_loss

            for h in self.hooks:
                h.on_validation_batch_end(self, val_batch, results)

        if test:
            return

        # weighted average over batches
        if self.loss_is_normalized or self.mol_loss_norm:
            val_loss /= n_val

        # if running in parallel, save the validation loss
        # and pick up the losses from the other processes too

        if self.parallel:
            self.save_val_loss(val_loss, n_val)
            val_loss = self.load_val_loss()

        for h in self.hooks:
            # delay this until after we know what the real
            # val loss is (e.g. if it's from a metric)
            if isinstance(h, ReduceLROnPlateauHook):
                continue

            h.on_validation_end(self, val_loss)
            metric_dic = getattr(h, "metric_dic", None)
            if metric_dic is None:
                continue
            if self.metric_as_loss in metric_dic:
                val_loss = metric_dic[self.metric_as_loss]
                if self.metric_objective.lower() == "maximize":
                    val_loss *= -1

        for h in self.hooks:
            if not isinstance(h, ReduceLROnPlateauHook):
                continue
            h.on_validation_end(self, val_loss)

        if self.best_loss > val_loss:
            self.best_loss = val_loss
            self.save_as_best()
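
Compared with the earlier validation loop, the only change in the weighting is how vsize is chosen: the number of molecules when the loss is normalized per molecule, or the number of atoms when it is normalized per atom. A small sketch on a toy batch dictionary (same key names as above, made-up shapes):

import torch

# toy batch: 2 molecules with 3 and 5 atoms, so 8 rows in 'nxyz'
val_batch = {
    'num_atoms': torch.tensor([3, 5]),
    'nxyz': torch.rand(8, 4),
}

mol_loss_norm = True        # normalize the loss per molecule
loss_is_normalized = False  # per-atom normalization would use 'nxyz' instead

if mol_loss_norm:
    vsize = len(val_batch['num_atoms'])    # 2 molecules
elif loss_is_normalized:
    vsize = val_batch['nxyz'].size(0)      # 8 atoms
print(vsize)   # 2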
Example #7
    def train(self, device, n_epochs=MAX_EPOCHS):
        """Train the model for the given number of epochs on a specified
        device.

        Args:
            device (torch.device): device on which training takes place.
            n_epochs (int): number of training epochs.

        Note: Depending on the `hooks`, training can stop earlier than `n_epochs`.

        """
        self.to(device)
        self._stop = False
        # initialize loss, num_batches, and optimizer grad to 0
        loss = torch.tensor(0.0).to(device)
        num_batches = 0
        self.optimizer.zero_grad()
        self.save_as_best()

        for h in self.hooks:
            h.on_train_begin(self)
            if hasattr(h, "mini_batches"):
                h.mini_batches = self.mini_batches

        try:
            for _ in range(n_epochs):

                self._model.train()
                self.epoch += 1

                for h in self.hooks:
                    h.on_epoch_begin(self)

                if self._stop:
                    break

                for j, batch in self.tqdm_enum(self.train_loader):

                    batch = batch_to(batch, device)

                    for h in self.hooks:
                        h.on_batch_begin(self, batch)

                    results = self.call_model(batch, train=True)
                    mini_loss = self.get_loss(batch, results)
                    self.loss_backward(mini_loss)
                    if not torch.isnan(mini_loss):
                        loss += mini_loss.cpu().detach().to(device)

                    self.step += 1
                    # accumulate the loss over self.mini_batches batches
                    # before taking an optimizer step
                    num_batches += 1

                    if num_batches == self.mini_batches:
                        loss /= self.nloss
                        num_batches = 0
                        # effective number of batches so far
                        eff_batches = int((j + 1) / self.mini_batches)

                        self.optim_step(batch_num=eff_batches,
                                        device=device)

                        for h in self.hooks:
                            h.on_batch_end(self, batch, results, loss)

                        # reset loss and the optimizer grad

                        loss = torch.tensor(0.0).to(device)
                        self.optimizer.zero_grad()

                    if any((self.batch_stop,
                            self._stop, j == self.epoch_cutoff)):
                        break

                # reset for next epoch

                del mini_loss
                num_batches = 0
                loss = torch.tensor(0.0).to(device)
                self.optimizer.zero_grad()

                # store the checkpoint only if this is the base model,
                # otherwise it will get stored unnecessarily from other
                # gpus, which will cause IO issues

                if (self.epoch % self.checkpoint_interval == 0
                        and self.base):
                    self.store_checkpoint()

                # validation
                if (self.epoch % self.validation_interval == 0 or self._stop):
                    self.validate(device)

                for h in self.hooks:
                    h.on_epoch_end(self)

            # Training Ends
            # run hooks & store checkpoint

            for h in self.hooks:
                h.on_train_ends(self)

            if self.base:
                self.store_checkpoint()

        except Exception as e:
            for h in self.hooks:
                h.on_train_failed(self)

            raise e
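
One detail specific to this trainer is the NaN guard: backward is still called on every mini-batch loss, but non-finite values are kept out of the running total that the hooks later see. A tiny standalone sketch of that guard with made-up loss values (not the trainer's loss function):

import torch

mini_losses = [torch.tensor(0.2), torch.tensor(float('nan')), torch.tensor(0.4)]

loss = torch.tensor(0.0)
for mini_loss in mini_losses:
    # only accumulate finite losses for logging / averaging
    if not torch.isnan(mini_loss):
        loss = loss + mini_loss
print(loss)   # tensor(0.6000)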