Example #1
0
def sparse_bool_mask(x, mask, axis=0):
    """
    Applies a mask to a `tf.SparseTensor` along one axis, like
    `tf.boolean_mask` does for dense tensors.

    Entries of `x` whose coordinate along `axis` is selected by `mask` are
    kept, and that coordinate is renumbered so that the selected positions
    become `0, 1, ..., n_selected - 1` in their original order.

    The selection is resolved with a lookup into `mask` and a running count of
    the selected positions, so the cost is linear in the number of stored
    entries plus the length of `mask` (no pairwise comparison matrix is
    built).

    :param x: a `tf.SparseTensor`;
    :param mask: a rank 1 mask with one entry per position of `x` along
    `axis`; non-zero / True entries are kept;
    :param axis: non-negative Python int, the axis to mask (the `axis + 1`
    slices below do not support negative values);
    :return: a `tf.SparseTensor` with the same rank as `x`, whose size along
    `axis` is the number of selected positions.
    """
    keep = tf.cast(mask, tf.bool)
    keep_int = tf.cast(keep, tf.int64)
    # Size of the masked axis in the output
    n_indices = tf.reduce_sum(keep_int)
    # For every selected position, its rank among the selected positions;
    # values at non-selected positions are never used (filtered out by `sel`)
    new_position = tf.cumsum(keep_int) - 1
    # Coordinate along `axis` of every stored entry
    idx = x.indices[:, axis]
    # Mask over the stored entries: True where the entry survives
    sel = tf.gather(keep, idx)
    # Selected values
    values_new = tf.boolean_mask(x.values, sel, axis=0)
    # New coordinate along `axis` for the selected entries
    idx_new = tf.boolean_mask(tf.gather(new_position, idx), sel, axis=0)
    # New full indices tensor: original coordinates with column `axis` replaced
    indices_new = tf.boolean_mask(x.indices, sel, axis=0)
    indices_new = tf.concat([
        indices_new[:, :axis],
        tf.expand_dims(idx_new, 1),
        indices_new[:, axis + 1:],
    ], axis=1)
    # New shape: only the masked axis changes size
    shape_new = tf.concat(
        [x.dense_shape[:axis], [n_indices], x.dense_shape[axis + 1:]], axis=0)
    return tf.SparseTensor(indices_new, values_new, shape_new)
Example #2
0
    def call(self, inputs):
        """
        Pools the input, keeping only the nodes selected by `ops.top_k`.

        Sets `self.data_mode` to `'graph'` when three inputs are given and to
        `'single'` otherwise.

        :param inputs: a list `[X, A]` or `[X, A, I]` (presumably node
        features, adjacency matrix and segment IDs -- confirm with callers);
        when `I` is missing, an all-zeros vector is used in its place;
        :return: a list `[X_pooled, A_pooled]`, followed by `I_pooled` if
        three inputs were given, followed by the node mask if
        `self.return_mask` is set.
        """
        if len(inputs) == 3:
            feats, adj, seg_ids = inputs
            self.data_mode = 'graph'
        else:
            feats, adj = inputs
            # No segment IDs given: put every node in segment 0
            seg_ids = tf.zeros(tf.shape(feats)[:1], dtype=tf.int32)
            self.data_mode = 'single'
        # Rank 2 segment IDs are reduced to their first column
        if K.ndim(seg_ids) == 2:
            seg_ids = seg_ids[:, 0]

        adj_is_sparse = K.is_sparse(adj)

        # Score the nodes by projecting them on the normalised kernel, then
        # turn the indices returned by top_k into a mask of length N
        scores = K.dot(feats, K.l2_normalize(self.kernel))
        n_nodes = K.shape(feats)[-2]
        kept = ops.top_k(scores[:, 0], seg_ids, self.ratio, self.top_k_var)
        mask = tf.scatter_nd(
            tf.expand_dims(kept, 1), tf.ones_like(kept), (n_nodes, ))

        # Gating the features with the scores keeps the layer differentiable
        gated = feats * self.gating_op(scores)

        # tf.boolean_mask does not accept a negative axis, so the node axis
        # is derived from the rank of the adjacency matrix
        if len(K.int_shape(adj)) == 2:
            axis = 0
        else:
            axis = 1

        # Reduce the features
        feats_pooled = tf.boolean_mask(gated, mask, axis=axis)

        # Square the adjacency matrix (the second operand is always dense)
        adj_dense = tf.sparse.to_dense(adj) if adj_is_sparse else adj
        adj_pooled = K.dot(adj, adj_dense)

        # Reduce the squared adjacency matrix along both of its node axes
        for node_axis in (axis, axis + 1):
            adj_pooled = tf.boolean_mask(adj_pooled, mask, axis=node_axis)
        if adj_is_sparse:
            adj_pooled = tf.contrib.layers.dense_to_sparse(adj_pooled)

        output = [feats_pooled, adj_pooled]

        # Reduce the segment IDs, only when they were given as input
        if self.data_mode == 'graph':
            output.append(tf.boolean_mask(seg_ids[:, None], mask)[:, 0])

        if self.return_mask:
            output.append(mask)

        return output
Example #3
0
def upsampling_from_mask(inputs):
    """
    Scatters pooled tensors back to the positions selected by a mask.

    A selection matrix `S` is built from the rows of an identity matrix picked
    by the mask; its transpose places row `j` of a pooled tensor at the
    position of the `j`-th selected entry, leaving zeros everywhere else.

    :param inputs: a list `[X, A, I, M]`, where `M` is a rank 1 mask and `X`,
    `A`, `I` have one row per selected entry of `M`;
    :return: a list `[X_out, A_out, I_out]` with `X_out = S^T X`,
    `A_out = (A S)^T S` (equal to `S^T A S` only when `A` is symmetric) and
    `I_out = S^T I` cast to int32.
    """
    feats, adj, seg_ids, mask = inputs

    # Rows of the identity picked by the mask: (n_selected, n_total)
    n_total = tf.shape(mask)[0]
    select = tf.boolean_mask(tf.eye(n_total), mask)
    scatter = tf.transpose(select)

    feats_up = K.dot(scatter, feats)

    adj_cols_up = K.dot(adj, select)
    adj_up = K.dot(K.transpose(adj_cols_up), select)

    # The IDs go through a float matmul as a column, then back to int32
    seg_column = K.cast(seg_ids[:, None], tf.float32)
    seg_up = K.cast(K.dot(scatter, seg_column)[:, 0], tf.int32)

    return [feats_up, adj_up, seg_up]
Example #4
0
def top_k(scores, I, ratio, top_k_var):
    """
    Returns indices to get the top K values in `scores` segment-wise, with
    segments defined by I. K is not fixed, but it is defined as a ratio of the
    number of elements in each segment (`ceil(ratio * segment_size)`).

    The scores are written into a padded `(n_graphs, max_n_nodes)` buffer
    backed by `top_k_var`, each row is sorted in descending order, and the
    first K column indices of each row are mapped back to positions in
    `scores`.

    :param scores: a rank 1 tensor with scores;
    :param I: a rank 1 tensor with segment IDs (NOTE(review): the offset
    arithmetic below assumes that IDs are sorted, so that each segment is a
    contiguous run -- `tf.segment_sum` documents the same requirement);
    :param ratio: float, ratio of elements to keep for each segment
    (NOTE(review): presumably in (0, 1]; with a larger ratio K would exceed
    the segment size and padding slots would be selected -- confirm);
    :param top_k_var: a tf.Variable without shape validation (e.g.,
    `tf.Variable(0.0, validate_shape=False)`); it is overwritten on every
    call;
    :return: a rank 1 tensor containing the indices to get the top K values of
    each segment in `scores`.
    """
    num_nodes = tf.segment_sum(tf.ones_like(I),
                               I)  # Number of nodes in each graph
    cumsum = tf.cumsum(num_nodes)  # Cumulative number of nodes (A, A+B, A+B+C)
    cumsum_start = cumsum - num_nodes  # Start index of each graph
    n_graphs = tf.shape(num_nodes)[0]  # Number of graphs in batch
    max_n_nodes = tf.reduce_max(num_nodes)  # Order of biggest graph in batch
    batch_n_nodes = tf.shape(I)[0]  # Number of overall nodes in batch
    to_keep = tf.ceil(ratio * tf.cast(num_nodes, tf.float32))
    to_keep = tf.cast(to_keep, tf.int32)  # Nodes to keep in each graph

    # Slot of each node in the flattened padded buffer:
    # (offset of the node within its graph) + (graph ID * row length)
    index = tf.range(batch_n_nodes)
    index = (index - tf.gather(cumsum_start, I)) + (I * max_n_nodes)

    # Padded buffer, pre-filled with a value lower than any real score
    y_min = tf.reduce_min(scores)
    dense_y = tf.ones((n_graphs * max_n_nodes, ))
    dense_y = dense_y * tf.cast(
        y_min - 1, tf.float32
    )  # subtract 1 to ensure that filler values do not get picked
    dense_y = tf.assign(
        top_k_var, dense_y, validate_shape=False
    )  # top_k_var is a variable with unknown shape, defined elsewhere
    # Write the real scores into their slots, then view one row per graph
    dense_y = tf.scatter_update(dense_y, index, scores)
    dense_y = tf.reshape(dense_y, (n_graphs, max_n_nodes))

    # Column indices of each row in descending score order, shifted by the
    # start index of the graph so that they index into `scores`
    perm = tf.argsort(dense_y, direction='DESCENDING')
    perm = perm + cumsum_start[:, None]
    perm = tf.reshape(perm, (-1, ))

    # For each graph: to_keep ones followed by (max_n_nodes - to_keep) zeros,
    # i.e. a 1.0/0.0 mask that keeps the first K entries of every sorted row
    to_rep = tf.tile(tf.constant([1., 0.]), (n_graphs, ))
    rep_times = tf.reshape(
        tf.concat((to_keep[:, None], (max_n_nodes - to_keep)[:, None]), -1),
        (-1, ))
    mask = tf_repeat_1d(to_rep, rep_times)

    perm = tf.boolean_mask(perm, mask)

    return perm
Example #5
0
def tf_repeat_1d(x, repeats):
    """
    Builds a rank 1 tensor in which `x[i]` appears `repeats[i]` consecutive
    times, in the same order as in `x`.
    :param x: a rank 1 tensor;
    :param repeats: a rank 1 tensor;
    :return: a rank 1 tensor, of shape `(sum(repeats), )`.
    """
    longest = tf.reduce_max(repeats)
    # Row i holds `longest` copies of x[i]
    copies = tf.tile(tf.expand_dims(x, 1), [1, longest])
    # Row i keeps only its first repeats[i] copies
    keep = tf.range(longest) < tf.expand_dims(repeats, 1)
    return tf.reshape(tf.boolean_mask(copies, keep), [-1])