def test_get_labels_bond_order(self):
    """get_labels() on a depth-3, bond-order-aware tree fitted to methane."""
    tree = ConnectivityTree(depth=3, use_bond_order=True)
    features = tree.fit_transform([METHANE])
    got = tree.get_labels()
    # Every feature column must have exactly one label.
    self.assertEqual(features.shape[1], len(got))
    self.assertEqual(
        got,
        (
            '1-C_1_H-1__2-H_1_C-3',
            '1-H_1_C-1__1-H_1_H-3',
        ),
    )
def test_get_labels_unknown(self):
    """get_labels() with add_unknown=True ends with the UNKNOWN label."""
    tree = ConnectivityTree(depth=2, add_unknown=True)
    features = tree.fit_transform([METHANE])
    got = tree.get_labels()
    # Label count must match the number of feature columns produced.
    self.assertEqual(features.shape[1], len(got))
    known = ('0-Root-C-1__1-C-H-4', '0-Root-H-1__1-H-C-1')
    self.assertEqual(got, known + (UNKNOWN,))
def test_get_labels_coordination(self):
    """get_labels() with use_coordination=True on a depth-1 tree of methane."""
    tree = ConnectivityTree(depth=1, use_coordination=True)
    features = tree.fit_transform([METHANE])
    got = tree.get_labels()
    # One label per output column.
    self.assertEqual(features.shape[1], len(got))
    self.assertEqual(got, tuple(['0-Root-C4-1', '0-Root-H1-1']))
def test_fit_transform(self):
    """fit_transform() on ALL_DATA reproduces the ALL_ATOM_TREE fixture."""
    tree = ConnectivityTree(depth=2)
    features = tree.fit_transform(ALL_DATA)
    # Elementwise comparison; presumably ALL_ATOM_TREE is a numpy array of
    # the same shape as the result (defined elsewhere) — TODO confirm.
    matches = features == ALL_ATOM_TREE
    self.assertTrue(matches.all())