Example 1
def generate_ensembled_predictions(data_root: str, predictions_path: str,
                                   split: str) -> losses.Predictions:
    """Ensemble checkpoints from all WIDs in XID and generates submission file."""

    array_dict = data_utils.get_arrays(data_root=data_root,
                                       return_pca_embeddings=False,
                                       return_adjacencies=False)

    # Load all valid and test predictions.
    node_idx_to_logits_list = load_predictions(predictions_path, split)

    # Assert that the indices loaded are as expected.
    expected_idx = array_dict[f'{split}_indices']
    idx_found = np.array(list(node_idx_to_logits_list.keys()))
    assert np.all(np.sort(idx_found) == expected_idx)

    if split == 'valid':
        true_labels = array_dict['paper_label'][expected_idx.astype(np.int32)]
    else:
        # Don't know the test labels.
        true_labels = np.full(expected_idx.shape, np.nan)

    # Ensemble together all predictions.
    return ensemble_predictions(node_idx_to_logits_list, true_labels,
                                expected_idx)
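
A minimal usage sketch; the paths below are hypothetical placeholders, not the
repository's actual directory layout:

predictions = generate_ensembled_predictions(
    data_root='/path/to/mag240m_data',        # hypothetical path
    predictions_path='/path/to/predictions',  # hypothetical path
    split='valid')
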
Example 2
def _read_adjacency_indices():
    # Get adjacencies.
    return data_utils.get_arrays(
        data_root=FLAGS.data_root,
        use_fused_node_labels=False,
        use_fused_node_adjacencies=False,
        return_pca_embeddings=False,
    )
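
Assuming the absl flags have already been parsed, the returned dictionary can be
inspected directly; a small sketch (the exact keys depend on
data_utils.get_arrays and are not listed here):

adjacency_arrays = _read_adjacency_indices()
for name, array in adjacency_arrays.items():
    print(name, getattr(array, 'shape', None))
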
Example 3
def main(argv):
    del argv
    array_dict = data_utils.get_arrays(data_root=_DATA_ROOT.value,
                                       return_pca_embeddings=False,
                                       return_adjacencies=False)

    os.makedirs(_OUTPUT_DIR.value, exist_ok=True)
    data_utils.generate_k_fold_splits(train_idx=array_dict['train_indices'],
                                      valid_idx=array_dict['valid_indices'],
                                      output_path=_OUTPUT_DIR.value,
                                      num_splits=data_utils.NUM_K_FOLD_SPLITS)
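
This main reads absl-style flag holders (_DATA_ROOT.value, _OUTPUT_DIR.value), so
the enclosing script presumably uses the standard absl entry point; a sketch
under that assumption:

from absl import app

if __name__ == '__main__':
    app.run(main)
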
Example 4
def build_dataset_iterator(
    data_root: str,
    split: str,
    dynamic_batch_size_config: config_dict.ConfigDict,
    online_subsampling_kwargs: dict,  # pylint: disable=g-bare-generic
    debug: bool = False,
    is_training: bool = True,
    k_fold_split_id: Optional[int] = None,
    ratio_unlabeled_data_to_labeled_data: float = 0.0,
    use_all_labels_when_not_training: bool = False,
    use_dummy_adjacencies: bool = False,
):
    """Returns an iterator over Batches from the dataset."""

    if split == 'test':
        use_all_labels_when_not_training = True

    if not is_training:
        ratio_unlabeled_data_to_labeled_data = 0.0

    # Load the master data arrays.
    with LOADING_RAW_ARRAYS_LOCK:
        array_dict = data_utils.get_arrays(
            data_root,
            k_fold_split_id=k_fold_split_id,
            use_dummy_adjacencies=use_dummy_adjacencies)

    node_labels = array_dict['paper_label'].reshape(-1)
    # Build indicator vectors marking which nodes are in the train / validation
    # splits; these are used below to decide which labels may be exposed as
    # input features.
    train_indices = array_dict['train_indices'].astype(np.int32)
    is_train_index = np.zeros(node_labels.shape[0], dtype=np.int32)
    is_train_index[train_indices] = 1
    valid_indices = array_dict['valid_indices'].astype(np.int32)
    is_valid_index = np.zeros(node_labels.shape[0], dtype=np.int32)
    is_valid_index[valid_indices] = 1
    is_train_or_valid_index = is_train_index + is_valid_index

    def sstable_to_intermediate_graph(graph):
        # Converts a raw subsampled graph into an intermediate graph with one-hot
        # node features and globally unique "absolute" node indices.
        indices = tf.cast(graph.nodes['index'], tf.int32)
        first_index = indices[..., 0]

        # Add an additional absolute index by adding offsets to the author and
        # institution indices.
        absolute_index = graph.nodes['index']
        is_author = graph.nodes['type'] == 1
        absolute_index = tf.where(is_author,
                                  absolute_index + data_utils.NUM_PAPERS,
                                  absolute_index)
        is_institution = graph.nodes['type'] == 2
        absolute_index = tf.where(
            is_institution,
            absolute_index + data_utils.NUM_PAPERS + data_utils.NUM_AUTHORS,
            absolute_index)

        is_same_as_central_node = tf.math.equal(indices, first_index)
        input_nodes = graph.nodes
        graph = graph._replace(
            nodes={
                'one_hot_type': tf.one_hot(
                    tf.cast(input_nodes['type'], tf.int32), 3),
                'one_hot_depth': tf.one_hot(
                    tf.cast(input_nodes['depth'], tf.int32),
                    _MAX_DEPTH_IN_SUBGRAPH),
                'year': tf.expand_dims(input_nodes['year'], axis=-1),
                'label': tf.one_hot(
                    tf.cast(input_nodes['label'], tf.int32), NUM_CLASSES),
                'is_same_as_central_node': is_same_as_central_node,
                # Only first node in graph has a valid label.
                'is_central_node': tf.one_hot(
                    0, tf.shape(input_nodes['label'])[0]),
                'index': input_nodes['index'],
                'absolute_index': absolute_index,
            },
            globals=tf.expand_dims(graph.globals, axis=-1),
        )

        return graph

    ds = data_utils.get_graph_subsampling_dataset(
        split,
        array_dict,
        shuffle_indices=is_training,
        ratio_unlabeled_data_to_labeled_data=
        ratio_unlabeled_data_to_labeled_data,
        max_nodes=dynamic_batch_size_config.n_node - 1,  # Keep space for pads.
        max_edges=dynamic_batch_size_config.n_edge,
        **online_subsampling_kwargs)
    if debug:
        ds = ds.take(50)
    ds = ds.map(sstable_to_intermediate_graph,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if is_training:
        ds = ds.shard(jax.process_count(), jax.process_index())
        ds = ds.shuffle(buffer_size=1 if debug else 128)
        ds = ds.repeat()
    ds = ds.prefetch(1 if debug else tf.data.experimental.AUTOTUNE)
    np_ds = iter(tfds.as_numpy(ds))
    batched_np_ds = batching_utils.dynamically_batch(
        np_ds,
        **dynamic_batch_size_config,
    )

    def intermediate_graph_to_batch(graph):
        # Converts an intermediate graph into a `Batch`, attaching label masks,
        # label-as-feature inputs, one-hot features and PCA embeddings.
        central_node_mask = graph.nodes['is_central_node']
        label = graph.nodes['label']
        node_indices = graph.nodes['index']
        absolute_indices = graph.nodes['absolute_index']

        ### Construct label as a feature for non-central nodes.
        # First do a lookup with the node indices, using np.minimum to ensure we
        # do not index out of bounds, since num_authors is larger than num_papers.
        is_same_as_central_node = graph.nodes['is_same_as_central_node']
        capped_indices = np.minimum(node_indices, node_labels.shape[0] - 1)
        label_as_feature = node_labels[capped_indices]
        # Nodes which are not in the train set should get the `num_classes` label.
        # Nodes in the test set, as well as non-arXiv nodes, have -1 or NaN labels.

        # Mask out invalid labels and non-papers.
        use_label_as_feature = np.logical_and(
            label_as_feature >= 0, graph.nodes['one_hot_type'][..., 0])
        if split == 'train' or not use_all_labels_when_not_training:
            # Mask out validation papers and non-arXiv papers that got labels
            # from fusing with arXiv papers.
            use_label_as_feature = np.logical_and(
                is_train_index[capped_indices], use_label_as_feature)
        label_as_feature = np.where(use_label_as_feature, label_as_feature,
                                    NUM_CLASSES)
        # Mask out central node label in case it appears again.
        label_as_feature = np.where(is_same_as_central_node, NUM_CLASSES,
                                    label_as_feature)
        # Nodes which are not papers get `NUM_CLASSES+1` label.
        label_as_feature = np.where(graph.nodes['one_hot_type'][..., 0],
                                    label_as_feature, NUM_CLASSES + 1)

        nodes = {
            'label_as_feature': label_as_feature,
            'year': graph.nodes['year'],
            'bitstring_year': _get_bitstring_year_representation(
                graph.nodes['year']),
            'one_hot_type': graph.nodes['one_hot_type'],
            'one_hot_depth': graph.nodes['one_hot_depth'],
        }

        graph = graph._replace(
            nodes=nodes,
            globals={},
        )
        is_train_or_valid_node = np.logical_and(
            is_train_or_valid_index[capped_indices],
            graph.nodes['one_hot_type'][..., 0])
        if is_training:
            label_mask = np.logical_and(central_node_mask,
                                        is_train_or_valid_node)
        else:
            # `label_mask` is used by the prediction calculator to index into valid
            # central nodes. Since that computation is only done when not training,
            # and at that time we are guaranteed all central nodes have valid
            # labels, we just set label_mask = central_node_mask when not training.
            label_mask = central_node_mask
        batch = Batch(graph=graph,
                      node_labels=label,
                      central_node_mask=central_node_mask,
                      label_mask=label_mask,
                      node_indices=node_indices,
                      absolute_node_indices=absolute_indices)

        # Transform integers into one-hots.
        batch = _add_one_hot_features_to_batch(batch)

        # Gather PCA features.
        return _add_embeddings_to_batch(batch, array_dict['bert_pca_129'])

    batch_list = []
    for batch in batched_np_ds:
        with jax.profiler.StepTraceAnnotation('batch_postprocessing'):
            batch = intermediate_graph_to_batch(batch)
        if is_training:
            batch_list.append(batch)
            if len(batch_list) == jax.local_device_count():
                yield jax.device_put_sharded(batch_list, jax.local_devices())
                batch_list = []
        else:
            yield batch
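
A usage sketch for the iterator; the batch-size configuration and subsampling
kwargs below are illustrative assumptions, not the repository's actual settings:

from ml_collections import config_dict

batch_config = config_dict.ConfigDict(
    dict(n_node=256, n_edge=512, n_graph=8))  # assumed field names and sizes
iterator = build_dataset_iterator(
    data_root='/path/to/mag240m_data',  # hypothetical path
    split='train',
    dynamic_batch_size_config=batch_config,
    online_subsampling_kwargs={},       # subsampler-specific options, assumed
    debug=True)
first_sharded_batch = next(iterator)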