Ejemplo n.º 1
0
 def _create_generator(self, *args, **kwargs) -> \
         Union[GraphBatchDistanceConvert, GraphBatchGenerator]:
     """
     Build the batch generator that matches the graph converter in use.

     All positional and keyword arguments are forwarded unchanged to the
     generator constructor.

     Returns:
         A GraphBatchDistanceConvert when ``self.graph_converter`` exposes a
         ``bond_converter`` attribute (it is passed on as the
         ``distance_converter`` keyword), otherwise a GraphBatchGenerator.
     """
     converter = self.graph_converter
     # Without a bond converter there is nothing to expand: plain generator.
     if not hasattr(converter, 'bond_converter'):
         return GraphBatchGenerator(*args, **kwargs)
     # Any caller-supplied 'distance_converter' is overridden here.
     kwargs['distance_converter'] = converter.bond_converter
     return GraphBatchDistanceConvert(*args, **kwargs)
Ejemplo n.º 2
0
 def test_graph_batch_distance_converter(self):
     """
     Check the array shapes of the first batch produced by
     GraphBatchDistanceConvert when given a GaussianDistance converter.
     """
     # Two random samples: feature blocks of 3 and 2 rows (4 columns each),
     # bond vectors of length 2 and 1, and one (1, 2) global block apiece.
     feature_list = [np.random.normal(size=(rows, 4)) for rows in (3, 2)]
     bond_list = [np.random.normal(size=(count, )) for count in (2, 1)]
     global_list = [np.random.normal(size=(1, 2)) for _ in range(2)]
     first_indices = [np.array([0, 1]), np.array([0])]
     second_indices = [np.array([1, 2]), np.array([1])]
     target_values = np.random.normal(size=(2, 1))

     # 20 centers between 0 and 5 with width 0.5 for the distance converter
     gaussian_centers = np.linspace(0, 5, 20)
     gaussian_width = 0.5
     converter = GaussianDistance(gaussian_centers, gaussian_width)

     gen = GraphBatchDistanceConvert(feature_list,
                                     bond_list,
                                     global_list,
                                     first_indices,
                                     second_indices,
                                     target_values,
                                     batch_size=2,
                                     distance_converter=converter)
     first_batch = gen[0]
     batch_inputs = first_batch[0]

     # Expected shapes of the first five input arrays, in positional order.
     expected_input_shapes = [
         [1, 5, 4],
         [1, 3, 20],
         [1, 2, 2],
         [1, 3],
         [1, 3],
     ]
     for position, expected_shape in enumerate(expected_input_shapes):
         self.assertListEqual(list(batch_inputs[position].shape),
                              expected_shape)
     # The target array is the second element of the batch.
     self.assertListEqual(list(first_batch[1].shape), [1, 2, 1])
Ejemplo n.º 3
0
 def _create_generator(self, *args, **kwargs):
     """
     Create the batch generator, forwarding every argument as given.

     If ``self.graph_convertor`` has a ``bond_convertor`` attribute, it is
     injected as the ``distance_convertor`` keyword and a
     GraphBatchDistanceConvert is returned; in every other case a
     GraphBatchGenerator is returned.
     """
     # Early exit: no bond convertor available on the graph convertor.
     if not hasattr(self.graph_convertor, 'bond_convertor'):
         return GraphBatchGenerator(*args, **kwargs)
     kwargs['distance_convertor'] = self.graph_convertor.bond_convertor
     return GraphBatchDistanceConvert(*args, **kwargs)
Ejemplo n.º 4
0
def evaluate(test_graphs, test_targets):
    """
    Compute the mean absolute error of the module-level ``model`` on a
    held-out set of graphs.

    Args:
        test_graphs (list): list of graphs
        test_targets (list): list of target properties

    Returns:
        mean absolute error between the flattened predictions and the
        flattened targets yielded by the batch generator
    """
    converter = model.graph_converter
    flat_data = converter.get_flat_data(test_graphs, test_targets)
    gen = GraphBatchDistanceConvert(
        *flat_data,
        batch_size=128,
        distance_converter=converter.bond_converter)

    predicted, actual = [], []
    # Walk the generator by index; each item is indexed as (inputs, targets).
    for batch_idx in range(len(gen)):
        batch = gen[batch_idx]
        predicted += model.predict(batch[0]).ravel().tolist()
        actual += batch[1].ravel().tolist()

    abs_errors = np.abs(np.array(predicted) - np.array(actual))
    return np.mean(abs_errors)
Ejemplo n.º 5
0
 def test_graph_to_inputs_and_class_generator(self):
     """
     Run graph_to_inputs on repeated structures and verify the ranks and
     batch dimensions of the first batch from GraphBatchDistanceConvert.
     """
     graphs = 4 * [index_rep_from_structure(s) for s in self.structures]
     mp_ids = 5 * ['mp-19017', 'mp-2998']
     targets = 5 * [0.1, 0.2]

     flat_inputs = graph_to_inputs(mp_ids, graphs, targets)
     self.assertEqual(len(flat_inputs), 7)

     # The generator receives everything except the last returned element.
     generator_args = flat_inputs[:-1]
     gen = GraphBatchDistanceConvert(*generator_args,
                                     distance_converter=GaussianDistance(),
                                     batch_size=2)
     first_batch = gen[0]
     inputs = first_batch[0]
     labels = first_batch[1]

     # only one graph, therefore the batch dimension is 1
     leading_dims = [arr.shape[0] for arr in inputs]
     self.assertListEqual(leading_dims, [1] * len(inputs))

     # Expected number of dimensions for the first three inputs:
     #   0 -> atom   (1*N, N being the total number of atoms)
     #   1 -> bond   (1*M*G)
     #   2 -> global (1*U*K)
     for position, expected_rank in ((0, 2), (1, 3), (2, 3)):
         self.assertEqual(len(inputs[position].shape), expected_rank)

     # inputs 3..6 are all expected to be two-dimensional
     remaining_ranks = [len(inputs[pos].shape) for pos in range(3, 7)]
     self.assertListEqual(remaining_ranks, [2] * 4)

     # target is 1*2(crystal)*1(target dimension)
     self.assertListEqual(list(labels.shape), [1, 2, 1])
Ejemplo n.º 6
0
 def _create_generator(self, *args, **kwargs):
     """
     Return a batch generator built from the forwarded arguments.

     When ``self.distance_convertor`` is not None it is passed along as the
     ``distance_convertor`` keyword to GraphBatchDistanceConvert; when it is
     None a GraphBatchGenerator is created instead.
     """
     convertor = self.distance_convertor
     # No distance convertor configured: use the plain generator.
     if convertor is None:
         return GraphBatchGenerator(*args, **kwargs)
     kwargs['distance_convertor'] = convertor
     return GraphBatchDistanceConvert(*args, **kwargs)