def make_flat_target_matrix(full_relation, rel_ids, pos_heads, pos_tails, neg_heads, neg_tails, device):
    '''
    Build one target matrix holding a value channel per relation in rel_ids.

    inputs:
        full_relation: relation whose entities[0] / entities[1] n_instances
            give the first two dimensions of the matrix shape
        rel_ids: IDs used to look up the index arrays below
        pos_heads, pos_tails: dicts of ID : head / tail indices (positives)
        neg_heads, neg_tails: dicts of ID : head / tail indices (negatives)
        device: device the matrices are moved to

    returns:
        matrix_out: matrix built on the union of all relations' indices,
            with one column of values appended per entry of rel_ids
            (in rel_ids order)
    '''
    # The empty int32 seeds keep the result well-formed when rel_ids is empty.
    head_chunks = [np.array([], dtype=np.int32)]
    tail_chunks = [np.array([], dtype=np.int32)]
    for rid in rel_ids:
        head_chunks += [pos_heads[rid], neg_heads[rid]]
        tail_chunks += [pos_tails[rid], neg_tails[rid]]
    all_heads = np.concatenate(head_chunks)
    all_tails = np.concatenate(tail_chunks)

    channel_count = len(rel_ids)
    coords = torch.LongTensor(np.vstack((all_heads, all_tails)))
    zero_values = torch.zeros((coords.shape[1], channel_count))
    dims = (full_relation.entities[0].n_instances,
            full_relation.entities[1].n_instances,
            channel_count)

    # Zero-valued matrix over every (head, tail) pair of every relation; it
    # serves as the common index pattern for the per-relation targets.
    union_matrix = SparseMatrix(indices=coords, values=zero_values, shape=dims)
    union_matrix = union_matrix.to(device).coalesce_()

    # NOTE(review): from_other_sparse_matrix(m, k) presumably yields a matrix
    # with m's indices and k value channels -- confirm against SparseMatrix.
    matrix_out = SparseMatrix.from_other_sparse_matrix(union_matrix, 0)

    for rid in rel_ids:
        target = make_target_matrix(full_relation, pos_heads[rid],
                                    pos_tails[rid], neg_heads[rid],
                                    neg_tails[rid], device)
        # Adding onto a single-channel copy of the union pattern lines the
        # relation's values up with the rows of matrix_out.
        padded_target = SparseMatrix.from_other_sparse_matrix(union_matrix, 1) + target
        matrix_out.values = torch.cat((matrix_out.values, padded_target.values), 1)
        matrix_out.n_channels += 1
    return matrix_out
# Example #2
# 0
def combine_matrices_flat(full_relation, a_pos_heads, a_pos_tails, a_neg_heads,
                          a_neg_tails, ids, b_matrix, device):
    '''
    Combine the per-relation targets of A with the indices of b_matrix.

    inputs:
        full_relation: relation whose entities[0] / entities[1] n_instances
            give the first two dimensions of the matrix shape
        a_pos_heads, a_pos_tails: dicts of ID : head / tail indices of the
            positive samples of A
        a_neg_heads, a_neg_tails: dicts of ID : head / tail indices of the
            negative samples of A
        ids: IDs with which to access the indices of A
        b_matrix: a matrix whose indices we want to include in output
        device: device the matrices are moved to

    returns:
        out_matrix: matrix with indices & values of A as well as indices of B,
            holding one value channel per entry of ids (in ids order)
        masks: a dict of id : 1-D tensor of positions along the first
            dimension of the values that correspond to the indices of that
            relation in A (always 1-D, even for zero or one position)
    '''
    # Collect every chunk first and concatenate once instead of re-copying the
    # growing array on each iteration. The empty int32 seeds keep the result
    # well-formed when ids is empty.
    head_chunks = [np.array([], dtype=np.int32)]
    tail_chunks = [np.array([], dtype=np.int32)]
    for rel_id in ids:
        head_chunks += [a_pos_heads[rel_id], a_neg_heads[rel_id]]
        tail_chunks += [a_pos_tails[rel_id], a_neg_tails[rel_id]]
    full_heads = np.concatenate(head_chunks)
    full_tails = np.concatenate(tail_chunks)

    indices = torch.LongTensor(np.vstack((full_heads, full_tails)))
    values = torch.zeros((indices.shape[1], 1))
    shape = (full_relation.entities[0].n_instances,
             full_relation.entities[1].n_instances, 1)
    # Zero-valued, single-channel matrix over all (head, tail) pairs of A.
    full_a_matrix = SparseMatrix(indices=indices, values=values, shape=shape)
    full_a_matrix = full_a_matrix.to(device).coalesce_()

    # NOTE(review): from_other_sparse_matrix(m, k) presumably yields a matrix
    # with m's indices and k value channels -- confirm against SparseMatrix.
    # Adding 1 flags every entry of B in channel 0 of the indicator matrix.
    b_idx_matrix = SparseMatrix.from_other_sparse_matrix(b_matrix, 1)
    b_idx_matrix.values += 1

    # Indicator matrix over the combined indices of B and A; channel 0 flags B
    # and one further channel is appended per relation below.
    out_idx_matrix = b_idx_matrix + full_a_matrix
    out_matrix = SparseMatrix.from_other_sparse_matrix(out_idx_matrix, 0)

    for rel_id in ids:
        rel_matrix = make_target_matrix(full_relation, a_pos_heads[rel_id],
                                        a_pos_tails[rel_id],
                                        a_neg_heads[rel_id],
                                        a_neg_tails[rel_id], device)

        # Target values of this relation, aligned with the combined indices.
        rel_full_matrix = SparseMatrix.from_other_sparse_matrix(
            out_idx_matrix, 1) + rel_matrix
        out_matrix.values = torch.cat(
            [out_matrix.values, rel_full_matrix.values], 1)
        out_matrix.n_channels += 1

        # Indicator of this relation's entries (positives and negatives alike
        # are set to 1), aligned with the combined indices.
        rel_idx_matrix = SparseMatrix.from_other_sparse_matrix(rel_matrix, 1)
        rel_idx_matrix.values += 1
        rel_idx_full_matrix = SparseMatrix.from_other_sparse_matrix(
            out_idx_matrix, 1) + rel_idx_matrix
        out_idx_matrix.values = torch.cat(
            [out_idx_matrix.values, rel_idx_full_matrix.values], 1)
        out_idx_matrix.n_channels += 1

    masks = {}
    for channel_i, rel_id in enumerate(ids):
        # Channel 0 belongs to B, so relation i lives in channel i + 1.
        rel_flags = out_idx_matrix.values[:, channel_i + 1]
        # nonzero() on a 1-D tensor returns shape (k, 1). squeeze(1) drops only
        # that trailing dimension, so a single hit stays a 1-D tensor of
        # length 1 instead of collapsing to a 0-d scalar as a bare squeeze()
        # would.
        masks[rel_id] = rel_flags.nonzero().squeeze(1)
    return out_matrix, masks
def make_target_matrix(relation, pos_head, pos_tail, neg_head, neg_tail, device):
    '''
    Build a single-channel target matrix from positive and negative pairs.

    inputs:
        relation: relation whose entities[0] / entities[1] n_instances give
            the first two dimensions of the matrix shape
        pos_head, pos_tail: head / tail indices of the positive samples
        neg_head, neg_tail: head / tail indices of the negative samples
        device: device the matrix is moved to

    returns:
        data_target: coalesced matrix with value 1 at the positive pairs and
            value 0 at the negative pairs
    '''
    pos_count = pos_head.shape[0]
    pos_coords = np.vstack((pos_head, pos_tail))
    neg_count = neg_head.shape[0]
    neg_coords = np.vstack((neg_head, neg_tail))

    # Positives come first, then negatives, for both coordinates and labels.
    coords = np.concatenate((pos_coords, neg_coords), 1)
    labels = np.zeros((pos_count + neg_count, 1))
    labels[:pos_count] = 1.0

    dims = (relation.entities[0].n_instances,
            relation.entities[1].n_instances,
            1)
    target = SparseMatrix(indices=torch.LongTensor(coords),
                          values=torch.FloatTensor(labels),
                          shape=dims)
    return target.to(device).coalesce_()