Exemplo n.º 1
0
def evaluate(model,
             loader,
             loss_fn,
             device,
             return_results=True,
             loss_is_normalized=True,
             submodel=None,
             **kwargs):
    """Evaluate the current state of the model using a given dataloader.

    Args:
        model: model to evaluate; it is put in eval mode and moved to
            ``device`` before iterating.
        loader: iterable of batches; each batch is indexed with ``'nxyz'``
            and the size of that entry's first dimension is used as the
            batch weight.
        loss_fn: callable ``loss_fn(batch, results)`` returning a tensor.
        device: device the model and every batch are moved to.
        return_results (bool): if False, predictions and batches are not
            collected and two empty dicts are returned in their place.
        loss_is_normalized (bool): if True, each batch loss is multiplied by
            its batch weight and the total is divided by the summed weights
            (a weighted mean); if False, batch losses are simply summed.
        submodel (str, optional): name of an attribute of ``model`` to call
            instead of ``model`` itself. ``kwargs`` are NOT forwarded to it.
        **kwargs: extra keyword arguments forwarded to ``model(batch, ...)``.

    Returns:
        tuple: ``(all_results, all_batches, eval_loss)``. If the loader is
        empty, ``eval_loss`` is ``0.0``.
    """

    model.eval()
    model.to(device)

    eval_loss = 0.0
    n_eval = 0

    all_results = []
    all_batches = []

    # NOTE(review): no torch.no_grad() here -- presumably because the model
    # may rely on autograd during prediction; confirm before adding it.
    for batch in loader:
        batch = batch_to(batch, device)

        # the first dimension of 'nxyz' is the weight of this batch
        vsize = batch['nxyz'].size(0)
        n_eval += vsize

        # e.g. if the result is a sum of results from two models, and you
        # just want the prediction of one of those models
        if submodel is not None:
            results = getattr(model, submodel)(batch)
        else:
            results = model(batch, **kwargs)

        eval_batch_loss = loss_fn(batch, results).detach().cpu().numpy()

        if loss_is_normalized:
            # undo the per-batch normalization so batches can be re-weighted
            eval_loss += eval_batch_loss * vsize
        else:
            eval_loss += eval_batch_loss

        # only keep detached copies when the caller actually wants them;
        # otherwise they would just pile up in memory and be discarded
        if return_results:
            all_results.append(batch_detach(results))
            all_batches.append(batch_detach(batch))

    # weighted average over batches; an empty loader would otherwise raise
    # ZeroDivisionError here
    if loss_is_normalized and n_eval > 0:
        eval_loss /= n_eval

    if not return_results:
        return {}, {}, eval_loss

    # this step can be slow
    all_results = concatenate_dict(*all_results)
    all_batches = concatenate_dict(*all_batches)

    return all_results, all_batches, eval_loss
Exemplo n.º 2
0
    def test_inexistent_list_lists(self):
        """A key missing from one dict is padded with None for its items."""
        first = {'a': [[[1, 2]], [[3, 4]]], 'b': [5, 6]}
        second = {'b': [7, 8]}

        merged = concatenate_dict(first, second)

        # 'a' is absent from `second`, so its two slots become None
        self.assertEqual(
            merged,
            {'a': [[[1, 2]], [[3, 4]], None, None], 'b': [5, 6, 7, 8]},
        )
Exemplo n.º 3
0
    def test_concatenate(self):
        """Concatenate dataset item 0 with the slice [1:3] and show fields."""
        merged = concatenate_dict(self.dataset[0], self.dataset[1:3])

        # NOTE(review): this test only prints; nothing is asserted
        for field in ('energy', 'smiles'):
            print(merged[field])
Exemplo n.º 4
0
 def test_tensors(self):
     """1-element and 2-element tensors concatenate into a list of 0-d tensors."""
     left = {'a': torch.tensor([1.])}
     right = {'a': torch.tensor([2., 3.])}
     merged = concatenate_dict(left, right)
     wanted = {'a': [torch.tensor(value) for value in (1., 2., 3.)]}
     self.assertEqual(merged, wanted)
Exemplo n.º 5
0
def evaluate(model, loader, device, track, **kwargs):
    """
    Evaluate a model on a dataset.

    Args:
      model (nff.nn.models): original NFF model loaded
      loader (torch.utils.data.DataLoader): data loader
      device (Union[str, int]): device on which you run the model
      track: passed to ``get_iter_func`` to choose how ``loader`` is
        iterated (presumably toggles progress tracking -- confirm)
      **kwargs: forwarded to ``fps_and_pred`` for every batch
    Returns:
      all_results (dict): dictionary of results
      all_batches (dict): dictionary of ground truth
    """

    model.eval()
    model.to(device)

    # don't overload memory with unnecessary keys; built once, not per batch
    excluded_keys = frozenset(
        ['bond_idx', 'ji_idx', 'kj_idx', 'nbr_list', 'bonded_nbr_list'])

    all_results = []
    all_batches = []

    iter_func = get_iter_func(track)

    for batch in iter_func(loader):

        batch = batch_to(batch, device)
        results = fps_and_pred(model, batch, **kwargs)

        all_results.append(batch_detach(results))

        reduced_batch = {
            key: val
            for key, val in batch.items() if key not in excluded_keys
        }
        all_batches.append(batch_detach(reduced_batch))

    all_results = concatenate_dict(*all_results)
    all_batches = concatenate_dict(*all_batches)

    return all_results, all_batches
Exemplo n.º 6
0
 def test_concat_tensors(self):
     """Concatenating a dict with itself yields each tensor value twice."""
     single = {
         'a': torch.tensor(1),
         'b': [torch.tensor([2, 3])],
         'c': torch.tensor([[1, 0], [0, 1]]),
     }
     # the element each output list should contain (twice) per key
     per_item = {
         'a': torch.tensor(1),
         'b': torch.tensor([2, 3]),
         'c': torch.tensor([[1, 0], [0, 1]]),
     }
     merged = concatenate_dict(single, single)
     for name, produced in merged.items():
         reference = [per_item[name]] * 2
         for got, want in zip(produced, reference):
             self.assertTrue((got == want).all().item())
Exemplo n.º 7
0
    def setUp(self):
        """Build the quartz fixture and a Dataset holding three copies of it."""
        # one row per atom: atomic number (14 = Si, 8 = O) followed by
        # three coordinates
        nxyz = np.array(
            [[14.0, -1.19984241582007, 2.07818802527655, 4.59909615202747],
             [14.0, 1.31404847917993, 2.27599872954824, 2.7594569553608],
             [14.0, 2.39968483164015, 0.0, 0.919817758694137],
             [8.0, -1.06646793438585, 3.24694318819338, 0.20609293956337],
             [8.0, 0.235189576572621, 1.80712683722845, 3.8853713328967],
             [8.0, 0.831278357813231, 3.65430348422777, 2.04573213623004],
             [8.0, 3.34516925281323, 0.699883270597028, 5.31282465043663],
             [8.0, 1.44742296061415, 1.10724356663142, 1.6335462571033],
             [8.0, 2.74908047157262, 2.54705991759635, 3.47318545376996]])
        lattice = np.array([[5.02778179, 0.0, 3.07862843796742e-16],
                            [-2.513890895, 4.3541867548248, 3.07862843796742e-16],
                            [0.0, 0.0, 5.51891759]])

        self.quartz = {"nxyz": nxyz, "lattice": lattice}

        copies = [self.quartz, self.quartz, self.quartz]
        self.qtz_dataset = Dataset(concatenate_dict(*copies))
Exemplo n.º 8
0
 def test_concat_list_lists(self):
     """Concatenating dict_d with itself should equal dict_dd."""
     self.assertEqual(concatenate_dict(self.dict_d, self.dict_d), self.dict_dd)
Exemplo n.º 9
0
 def test_concat_single_dict_lists(self):
     """Concatenating a single dict should give back an equal dict."""
     self.assertEqual(concatenate_dict(self.dict_a_list), self.dict_a_list)
Exemplo n.º 10
0
 def test_concat_2(self):
     """dict_a concatenated with dict_c should equal dict_ac."""
     self.assertEqual(concatenate_dict(self.dict_a, self.dict_c), self.dict_ac)
Exemplo n.º 11
0
 def test_concat_1(self):
     """dict_a concatenated with dict_b should equal dict_ab."""
     self.assertEqual(concatenate_dict(self.dict_a, self.dict_b), self.dict_ab)