def test_schnet_preprocessor_with_tox21():
    """Parse the tox21 train SDF with SchNetPreprocessor and check its outputs.

    Every parsed entry is checked rather than one randomly sampled index, so
    the test is deterministic and a single malformed entry cannot slip
    through on some runs and fail on others.
    """
    preprocessor = SchNetPreprocessor()

    dataset = SDFFileParser(preprocessor, postprocess_label=None).parse(
        get_tox21_filepath('train'))['dataset']

    # Fail with a clear assertion instead of an obscure error from sampling
    # an index out of an empty dataset.
    assert len(dataset) > 0

    # Only len() and integer indexing are relied upon, as in the original.
    for index in range(len(dataset)):
        atoms, adjs = dataset[index]

        assert atoms.ndim == 1  # (atom, )
        assert atoms.dtype == numpy.int32
        # (atom from, atom to)
        assert adjs.ndim == 2
        assert adjs.dtype == numpy.float32
def test_schnet_preprocessor_default():
    """Parse a few SMILES with a default SchNetPreprocessor and check outputs.

    All parsed entries are checked (not one random index), which keeps the
    test deterministic and covers every input molecule that was parsed.
    """
    preprocessor = SchNetPreprocessor()

    dataset = SmilesParser(preprocessor).parse(
        ['C#N', 'Cc1cnc(C=O)n1C', 'c1ccccc1'])['dataset']

    # Guard against an empty result, which would otherwise make the loop
    # below pass vacuously.
    assert len(dataset) > 0

    for index in range(len(dataset)):
        atoms, adjs = dataset[index]

        assert atoms.ndim == 1  # (atom, )
        assert atoms.dtype == numpy.int32
        # (atom from, atom to)
        assert adjs.ndim == 2
        assert adjs.dtype == numpy.float32
def test_schnet_preprocessor_assert_raises():
    """Check that SchNetPreprocessor(max_atoms=3, out_size=2) raises ValueError.

    Here out_size (2) is smaller than max_atoms (3).
    """
    with pytest.raises(ValueError):
        SchNetPreprocessor(max_atoms=3, out_size=2)
def pp():
    """Return a SchNetPreprocessor constructed with its default arguments."""
    default_preprocessor = SchNetPreprocessor()
    return default_preprocessor