Example #1
def test_update_atom_array_2(single_atom_datasets, cutoff):
    atom_arrays, adj_arrays = single_atom_datasets
    actual_atom_arrays, actual_label_frequency = wle_update.update_atom_arrays(
        atom_arrays, adj_arrays, cutoff)

    # Note that labels after expansion need not be
    # the same as the original atom labels.
    # For example, assigning ids according to
    # appearance order
    # 0 -> 0, 1 -> 1, 2 -> 2, 5 -> 3, 4 -> 4,
    # results in
    # Atom arrays
    #   train: [[0], [1], [2]]
    #   val:   [[1], [1], [3]]
    #   test:  [[4], [4], [2]]
    # Label Frequency
    #    {'0': 1, '1': 3, '2': 2, '3': 1, '4': 2}
    # This is acceptable.

    train, val, test = actual_atom_arrays
    assert _is_all_same((train[1], val[0], val[1]))
    assert _is_all_same((train[2], test[2]))
    assert _is_all_same((test[0], test[1]))
    assert _is_all_different((train[0], train[1], train[2], val[2], test[0]))

    expect_label_frequency = {'0-': 1, '1-': 3, '2-': 2, '4-': 2, '5-': 1}
    # Equal as a multiset.
    assert (sorted(actual_label_frequency.values()) == sorted(
        expect_label_frequency.values()))
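The assertions above rely on the helpers _is_all_same and _is_all_different, which are not part of this listing. A minimal sketch of the behaviour they are assumed to have (array-aware, since the compared items may be small numpy arrays):

import itertools

import numpy


def _is_all_same(values):
    # True when every element equals the first one.
    values = [numpy.asarray(v) for v in values]
    return all(numpy.array_equal(v, values[0]) for v in values)


def _is_all_different(values):
    # True when no two elements are equal.
    values = [numpy.asarray(v) for v in values]
    return not any(numpy.array_equal(a, b)
                   for a, b in itertools.combinations(values, 2))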
Example #2
def apply_wle_for_datasets(datasets, cutoff=0, k=1):
    """
    Apply label Weisfeiler--Lehman Embedding for the tuple of datasets.

    Args:
        datasets: tuple of dataset (usually, train/val/test),
                     each dataset consists of atom_array and
                     adj_array and teach_signal
        cutoff: int, if more than 0, the expanded labels
                   whose freq <= cutoff will be removed.
        k: int, the number of iterations of neighborhood
              aggregation.

    Returns:
        - tuple of dataset (usually, train/val/test),
               each dataest consists of atom_number_array and
               adj_tensor with expanded labels
        - list of all labels, used in the dataset parts.
        - dictionary of label frequencies key:label valeu:frequency count
    """

    atom_arrays, adj_arrays, teach_signals = wle_io.load_dataset_elements(datasets)

    # Re-expand the atom labels k times; each pass aggregates one more
    # hop of neighborhood information.
    for _ in range(k):
        atom_arrays, labels_frequencies = wle_update.update_atom_arrays(
            atom_arrays, adj_arrays, cutoff)

    datasets_expanded = wle_io.create_datasets(atom_arrays, adj_arrays, teach_signals)
    expanded_labels = list(labels_frequencies.keys())
    return tuple(datasets_expanded), expanded_labels, labels_frequencies
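A hedged usage sketch of the function above. The call signature is taken from the definition, but the layout of each dataset (a sequence of (atom_array, adj_array, teach_signal) tuples) is an assumption read off the docstring, and the _toy_dataset helper and its values are hypothetical.

import numpy


def _toy_dataset(label_a, label_b, signal):
    # Hypothetical two-atom molecule with a single bond and a scalar target.
    atom_array = numpy.array([label_a, label_b], dtype=numpy.int32)
    adj_array = numpy.array([[0, 1], [1, 0]], dtype=numpy.float32)
    return [(atom_array, adj_array, numpy.float32(signal))]


train = _toy_dataset(0, 1, 1.0)
val = _toy_dataset(0, 1, 0.0)
test = _toy_dataset(1, 1, 1.0)

datasets_expanded, expanded_labels, label_frequencies = apply_wle_for_datasets(
    (train, val, test), cutoff=0, k=1)
# expanded_labels lists every expanded label seen across the three splits;
# label_frequencies maps each expanded label to its atom count.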
Example #3
def test_update_atom_array_with_different_sample_sizes(
        different_sample_size_datasets):
    atom_arrays, adj_arrays = different_sample_size_datasets
    actual_atom_arrays, actual_label_frequency = wle_update.update_atom_arrays(
        atom_arrays, adj_arrays, 0)

    all_atoms = sum([list(a.ravel()) for a in actual_atom_arrays], [])
    assert _is_all_same(all_atoms)

    expect_label_frequency = {'0-': 6}
    assert actual_label_frequency == expect_label_frequency
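The fixture different_sample_size_datasets is not shown in this listing. Given the expected frequency {'0-': 6} and the all-atoms-equal assertion, it plausibly holds six isolated label-0 atoms spread unevenly over the splits. A hypothetical reconstruction, assuming per-split 2-D atom arrays and 3-D adjacency arrays:

import numpy
import pytest


@pytest.fixture
def different_sample_size_datasets():
    # Hypothetical reconstruction: splits with 1, 2, and 3 molecules, each
    # molecule a single label-0 atom with no bonds, so every atom expands
    # to the neighbour-free label '0-' (six atoms in total).
    def _split(n_mols):
        atom_array = numpy.zeros((n_mols, 1), dtype=numpy.int32)
        adj_array = numpy.zeros((n_mols, 1, 1), dtype=numpy.float32)
        return atom_array, adj_array

    splits = [_split(n) for n in (1, 2, 3)]
    return tuple(s[0] for s in splits), tuple(s[1] for s in splits)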
Example #4
def test_update_atom_array_with_different_graph_size(
        different_graph_size_datasets):
    atom_arrays, adj_arrays = different_graph_size_datasets
    actual_atom_arrays, actual_label_frequency = wle_update.update_atom_arrays(
        atom_arrays, adj_arrays, 0)

    mols = [d[0] for d in actual_atom_arrays]
    for m in mols:
        assert _is_all_same(m)

    expect_label_frequency = {'0-': 1, '0-0': 2, '0-0.0': 3}
    assert actual_label_frequency == expect_label_frequency
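The expected keys suggest an expanded-label convention of the form '<own label>-<dot-joined sorted neighbour labels>': an isolated atom becomes '0-', an atom with one label-0 neighbour '0-0', and an atom with two label-0 neighbours '0-0.0'. A standalone sketch of that rule, inferred from the test expectations rather than taken from wle_update itself:

def expand_label(atom_labels, adj, index):
    # Inferred relabelling rule (an assumption, not the library code):
    # own label, a dash, then the sorted neighbour labels joined by dots.
    own = str(atom_labels[index])
    neighbours = sorted(str(atom_labels[j])
                        for j, bonded in enumerate(adj[index])
                        if bonded and j != index)
    return own + '-' + '.'.join(neighbours)


print(expand_label([0], [[0]], 0))                            # '0-'
print(expand_label([0, 0], [[0, 1], [1, 0]], 0))              # '0-0'
print(expand_label([0, 0, 0],
                   [[0, 1, 1], [1, 0, 0], [1, 0, 0]], 0))     # '0-0.0'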
Example #5
def apply_cwle_for_datasets(datasets, k=1):
    """
    Apply Concatenated Weisfeiler--Lehman embedding for the tuple of datasets.
    This also applicalbe for the Gated-sum Weisfeiler--Lehman embedding.

    Args:
        datasets: tuple of dataset (usually, train/val/test),
                     each dataset consists of atom_array and
                     adj_array and teach_signal
        k: int, the number of iterations of neighborhood
              aggregation.

    Returns:
        - tuple of dataset (usually, train/val/test),
               each dataest consists of atom_number_array,
               expanded_label_array, and adj_tensor
        - list of all expanded labels, used in the dataset parts.
        - dictionary of label frequencies key:label valeu:frequency count
    """

    if k <= 0:
        raise ValueError('Iterations should be a positive integer. '
                         'Found k={}'.format(k))

    atom_arrays, adj_arrays, teach_signals = wle_io.load_dataset_elements(datasets)

    for i in range(k):
        if i != k - 1:
            # Intermediate iterations relabel the atoms in place.
            atom_arrays, labels_frequencies = wle_update.update_atom_arrays(
                atom_arrays, adj_arrays, 0)
        else:
            # The last iteration writes the expanded labels into a separate
            # wle_arrays instead of overwriting atom_arrays, so both can be
            # passed to create_datasets below.
            wle_arrays, labels_frequencies = wle_update.update_atom_arrays(
                atom_arrays, adj_arrays, 0, False)

    datasets_expanded = wle_io.create_datasets(
        atom_arrays, adj_arrays, teach_signals, wle_arrays)
    expanded_labels = list(labels_frequencies.keys())
    return tuple(datasets_expanded), expanded_labels, labels_frequencies
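A matching usage sketch for the concatenated variant, reusing the hypothetical _toy_dataset helper sketched after Example #2:

train = _toy_dataset(0, 1, 1.0)
val = _toy_dataset(0, 1, 0.0)
test = _toy_dataset(1, 1, 1.0)

datasets_expanded, expanded_labels, label_frequencies = apply_cwle_for_datasets(
    (train, val, test), k=1)
# Unlike apply_wle_for_datasets, each returned dataset keeps the original
# atom_number_array and additionally carries the expanded_label_array, so a
# model can concatenate (CWLE) or gate (GWLE) the two embeddings.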
Example #6
def test_update_atom_array_twice(line_graph_datasets):
    atom_arrays, adj_arrays = line_graph_datasets

    for _ in range(2):
        atom_arrays, actual_label_frequency = wle_update.update_atom_arrays(
            atom_arrays, adj_arrays, 0)

    expect_label_frequency = {
        '0-1': 2,
        '1-0.1': 2,
        '1-1.1': 1,
        '2-': 2
    }  # atoms in test and val datasets
    assert actual_label_frequency == expect_label_frequency
Example #7
def test_update_atom_array(k3_datasets, cutoff):
    atom_arrays, adj_arrays = k3_datasets
    actual_atom_arrays, actual_label_frequency = wle_update.update_atom_arrays(
        atom_arrays, adj_arrays, cutoff)

    mols = [d[0] for d in actual_atom_arrays]
    for m in mols:
        assert _is_all_same(m)

    # train/val/test atoms must have different labels.
    assert _is_all_different((mols[0][0], mols[1][0], mols[2][0]))

    if cutoff >= 3:
        expect_label_frequency = {'0': 3, '1': 3, '2': 3}
    else:
        expect_label_frequency = {'0-0.0': 3, '1-1.1': 3, '2-2.2': 3}
    assert actual_label_frequency == expect_label_frequency
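A small worked check of the two branches, using the hypothetical expand_label helper sketched after Example #4. Each K3 vertex has two neighbours carrying its own label, so split i yields the expanded label 'i-i.i' three times; with cutoff >= 3 every expanded label has frequency <= cutoff and is dropped back to the original label, which is what the two expected dictionaries encode. The per-split labels 0, 1, 2 are an assumption about the k3_datasets fixture.

from collections import Counter

adj_k3 = [[0, 1, 1], [1, 0, 1], [1, 1, 0]]   # complete graph on three atoms
frequency = Counter()
for split_label in (0, 1, 2):                # hypothetical labels per split
    labels = [split_label] * 3
    for index in range(3):
        frequency[expand_label(labels, adj_k3, index)] += 1

print(dict(frequency))   # {'0-0.0': 3, '1-1.1': 3, '2-2.2': 3}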