def evaluate(model, loader, loss_fn, device, return_results=True,
             loss_is_normalized=True, submodel=None, **kwargs):
    """Evaluate the current state of the model using a given dataloader.

    Args:
        model: model to evaluate. It is switched to eval mode and moved to
            ``device`` before iterating.
        loader: iterable of batches. Each batch must expose
            ``batch['nxyz']`` with a ``size(0)`` method.
        loss_fn: callable invoked as ``loss_fn(batch, results)``; its return
            value is converted with ``.data.cpu().numpy()``.
        device: device handed to ``model.to`` and ``batch_to``.
        return_results (bool): if True, the detached results and batches of
            every iteration are collected and concatenated. If False, nothing
            is collected and two empty dictionaries are returned instead.
        loss_is_normalized (bool): if True, each batch loss is weighted by
            ``batch['nxyz'].size(0)`` and the total is divided by the sum of
            those sizes. If False, the batch losses are simply summed.
        submodel (str, optional): name of an attribute of ``model`` to call
            on the batch instead of ``model`` itself, e.g. when the model
            output is a sum of several models and only one is wanted.
        **kwargs: forwarded to ``model(batch, **kwargs)``. Not forwarded when
            ``submodel`` is given.

    Returns:
        tuple: ``(all_results, all_batches, eval_loss)``. The first two are
        the outputs of ``concatenate_dict`` when ``return_results`` is True,
        and ``{}`` otherwise.

    Raises:
        ZeroDivisionError: if ``loss_is_normalized`` is True and ``loader``
            yields no batches.
    """
    model.eval()
    model.to(device)

    eval_loss = 0.0
    n_eval = 0
    all_results = []
    all_batches = []

    for batch in loader:
        batch = batch_to(batch, device)

        # weight used for the running average of the loss
        vsize = batch['nxyz'].size(0)
        n_eval += vsize

        # e.g. if the result is a sum of results from two models, and you
        # just want the prediction of one of those models
        if submodel is not None:
            results = getattr(model, submodel)(batch)
        else:
            results = model(batch, **kwargs)

        eval_batch_loss = loss_fn(batch, results).data.cpu().numpy()
        if loss_is_normalized:
            eval_loss += eval_batch_loss * vsize
        else:
            eval_loss += eval_batch_loss

        # Only keep per-batch copies when the caller asked for them;
        # otherwise every batch of the dataset would be detached and held
        # in memory just to be thrown away at the end.
        if return_results:
            all_results.append(batch_detach(results))
            all_batches.append(batch_detach(batch))

    # weighted average over batches
    if loss_is_normalized:
        eval_loss /= n_eval

    if not return_results:
        return {}, {}, eval_loss

    # this step can be slow
    all_results = concatenate_dict(*all_results)
    all_batches = concatenate_dict(*all_batches)
    return all_results, all_batches, eval_loss
def test_inexistent_list_lists(self):
    """A key absent from one input is expected to be filled with ``None``.

    ``'a'`` (a list of nested lists) only exists in the first dictionary, so
    the expected output carries one ``None`` per entry of the second one.
    """
    with_nested = {'a': [[[1, 2]], [[3, 4]]], 'b': [5, 6]}
    without_nested = {'b': [7, 8]}

    merged = concatenate_dict(with_nested, without_nested)

    self.assertEqual(
        merged,
        {'a': [[[1, 2]], [[3, 4]], None, None], 'b': [5, 6, 7, 8]},
    )
def test_concatenate(self):
    """Smoke test: concatenate one dataset entry with a two-entry slice.

    Nothing is asserted; the merged ``'energy'`` and ``'smiles'`` values are
    only printed, so the test fails only if an exception is raised.
    """
    single_entry = self.dataset[0]
    entry_slice = self.dataset[1:3]

    merged = concatenate_dict(single_entry, entry_slice)

    for key in ('energy', 'smiles'):
        print(merged[key])
def test_tensors(self):
    """One-dimensional tensors are expected to be split into scalar tensors.

    A length-1 and a length-2 tensor under the same key should come out as a
    flat list of three zero-dimensional tensors.
    """
    first = {'a': torch.tensor([1.])}
    second = {'a': torch.tensor([2., 3.])}

    merged = concatenate_dict(first, second)

    expected = {'a': [torch.tensor(value) for value in (1., 2., 3.)]}
    self.assertEqual(merged, expected)
def evaluate(model, loader, device, track, **kwargs):
    """
    Evaluate a model on a dataset.

    Args:
        model (nff.nn.models): original NFF model loaded
        loader (torch.utils.data.DataLoader): data loader
        device (Union[str, int]): device on which you run the model
        track: forwarded to `get_iter_func`, whose return value wraps
            `loader` for iteration (presumably toggles progress
            tracking -- confirm against `get_iter_func`).
        **kwargs: forwarded to `fps_and_pred` for every batch.

    Returns:
        all_results (dict): dictionary of results
        all_batches (dict): dictionary of ground truth
    """
    # keys that are not stored with the batches, so that memory is not
    # overloaded with unnecessary entries
    skipped_keys = ('bond_idx', 'ji_idx', 'kj_idx',
                    'nbr_list', 'bonded_nbr_list')

    model.eval()
    model.to(device)

    wrap_loader = get_iter_func(track)
    predictions = []
    targets = []

    for batch in wrap_loader(loader):
        batch = batch_to(batch, device)
        predictions.append(batch_detach(fps_and_pred(model, batch, **kwargs)))

        kept = {}
        for key, val in batch.items():
            if key in skipped_keys:
                continue
            kept[key] = val
        targets.append(batch_detach(kept))

    return concatenate_dict(*predictions), concatenate_dict(*targets)
def test_concat_tensors(self):
    """Concatenating a dict of tensors with itself duplicates every entry.

    The input mixes a zero-dimensional tensor, a list holding one tensor and
    a two-dimensional tensor; each key is expected to end up as a list with
    two copies of the original value.
    """
    t = {
        'a': torch.tensor(1),
        'b': [torch.tensor([2, 3])],
        'c': torch.tensor([[1, 0], [0, 1]]),
    }
    tt = {
        'a': [torch.tensor(1)] * 2,
        'b': [torch.tensor([2, 3])] * 2,
        'c': [torch.tensor([[1, 0], [0, 1]])] * 2,
    }

    concat = concatenate_dict(t, t)

    # zip() stops at the shorter input, so a missing key or a truncated
    # (even empty) list would otherwise pass without comparing anything.
    self.assertEqual(set(concat), set(tt))

    for key, val in concat.items():
        self.assertEqual(len(val), len(tt[key]))
        for i, j in zip(val, tt[key]):
            self.assertTrue((i == j).all().item())
def setUp(self):
    """Build the quartz fixture and a dataset holding three copies of it.

    ``self.quartz`` stores a 9x4 ``nxyz`` array and a 3x3 ``lattice`` array.
    The first ``nxyz`` column is presumably the atomic number (14.0 / 8.0)
    followed by coordinates -- confirm against the ``Dataset`` conventions.
    """
    nxyz = np.array([
        [14.0, -1.19984241582007, 2.07818802527655, 4.59909615202747],
        [14.0, 1.31404847917993, 2.27599872954824, 2.7594569553608],
        [14.0, 2.39968483164015, 0.0, 0.919817758694137],
        [8.0, -1.06646793438585, 3.24694318819338, 0.20609293956337],
        [8.0, 0.235189576572621, 1.80712683722845, 3.8853713328967],
        [8.0, 0.831278357813231, 3.65430348422777, 2.04573213623004],
        [8.0, 3.34516925281323, 0.699883270597028, 5.31282465043663],
        [8.0, 1.44742296061415, 1.10724356663142, 1.6335462571033],
        [8.0, 2.74908047157262, 2.54705991759635, 3.47318545376996],
    ])
    lattice = np.array([
        [5.02778179, 0.0, 3.07862843796742e-16],
        [-2.513890895, 4.3541867548248, 3.07862843796742e-16],
        [0.0, 0.0, 5.51891759],
    ])
    self.quartz = {"nxyz": nxyz, "lattice": lattice}

    # the same dictionary object is passed three times
    replicas = [self.quartz for _ in range(3)]
    self.qtz_dataset = Dataset(concatenate_dict(*replicas))
def test_concat_list_lists(self):
    """Concatenating ``dict_d`` with itself must equal the ``dict_dd`` fixture."""
    doubled = concatenate_dict(self.dict_d, self.dict_d)

    self.assertEqual(doubled, self.dict_dd)
def test_concat_single_dict_lists(self):
    """Passing only ``dict_a_list`` must give back an equal dictionary."""
    result = concatenate_dict(self.dict_a_list)

    self.assertEqual(result, self.dict_a_list)
def test_concat_2(self):
    """Concatenating ``dict_a`` and ``dict_c`` must equal the ``dict_ac`` fixture."""
    merged = concatenate_dict(self.dict_a, self.dict_c)

    self.assertEqual(merged, self.dict_ac)
def test_concat_1(self):
    """Concatenating ``dict_a`` and ``dict_b`` must equal the ``dict_ab`` fixture."""
    merged = concatenate_dict(self.dict_a, self.dict_b)

    self.assertEqual(merged, self.dict_ab)