Example 1
def v_scheme(tree, depths, directions, neighbours):
    """Children of the parent's colleagues that are separated from the box"""
    D = tree.children.ndim - 1
    nonzero_directions = (directions != 0).any(-1)
    descents = sets.flat_cartesian_product(
        torch.tensor([-1, +1], device=tree.id.device), D)

    # The v list is many times bigger than the other lists, so we'll loop rather than
    # vectorize to preserve memory.
    result = []
    for d in nonzero_directions.nonzero().squeeze(1):
        colleagues = neighbours[tree.parents, d]
        for friend_descent in descents:
            friends = child_boxes(tree, colleagues, friend_descent)
            for own_descent in descents:
                offset = (-own_descent + 4 * directions[d] +
                          friend_descent) / 2
                if (offset.abs() <= 1).all(-1):
                    continue

                for depth in torch.arange(depths.domain,
                                          device=tree.id.device):
                    s = depths.slice(depth)
                    mask = (tree.descent[s] == own_descent
                            ).all(-1) & ~tree.terminal[colleagues[s]] & (
                                colleagues[s] >= 0)
                    result.append(
                        arrdict.arrdict(boxes=tree.id[s][mask],
                                        friends=friends[s][mask],
                                        offset=offset,
                                        depth=depth))

    return result
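
The filter in the innermost loop above is what decides well-separatedness: a (direction, own_descent, friend_descent) triple only lands in the v list when the resulting offset has some component larger than 1 in magnitude. A standalone sketch (plain PyTorch and itertools, not library code) that enumerates the same combinations for D=2 and counts how many survive:

import itertools

import torch

D = 2
descents = torch.tensor(list(itertools.product([-1, +1], repeat=D)))
directions = torch.tensor([d for d in itertools.product([-1, 0, +1], repeat=D) if any(d)])

kept = 0
for direction in directions:
    for friend_descent in descents:
        for own_descent in descents:
            offset = (-own_descent + 4 * direction + friend_descent) / 2
            if (offset.abs() <= 1).all():
                continue  # not well-separated; such pairs are handled by the other interaction lists
            kept += 1
print(f'{kept} of {len(directions) * len(descents)**2} combinations are well-separated')
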
Example 2
def interaction_scheme(tree, depths):
    """Constructs the datastructures needed to calculate the interactions between boxes.
    
    See Carrier, Greengard & Rokhlin's 1988 paper for a description of u, v, w, and x 
    interactions:

    https://pdfs.semanticscholar.org/97f0/d2a31d818ede922c9a59dc17f710642332ca.pdf

    §3.2, Notation, is what you're after, along with Fig 5.

    The datastructures are pretty heterogeneous because, well, performance. They're set
    up so the data needed can be accessed quickly without blowing up the memory budget.
    """
    D = tree.children.ndim - 1
    directions = sets.flat_cartesian_product(
        torch.tensor([-1, 0, +1], device=tree.id.device), D)
    neighbours = torch.stack(
        [neighbour_boxes(tree, tree.id, d) for d in directions], -1)

    w = w_pairs(tree, directions, neighbours)

    return arrdict.arrdict(w=ragged.from_pairs(w, len(tree.id), len(tree.id)),
                           x=ragged.from_pairs(w.flip((1, )), len(tree.id),
                                               len(tree.id)),
                           u=u_scheme(tree, neighbours),
                           v=v_scheme(tree, depths, directions, neighbours))
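
The direction set fed to neighbour_boxes above is just every D-vector with entries in {-1, 0, +1}. Assuming that's what sets.flat_cartesian_product(vals, D) produces (an assumption; the helper isn't shown in these examples), a plain-torch equivalent for D=2 looks like:

import torch

vals = torch.tensor([-1, 0, +1])
D = 2
# Every D-vector over {-1, 0, +1}: 3**D rows, one per direction
# (the all-zero row is the box itself).
directions = torch.cartesian_prod(*[vals] * D).reshape(-1, D)
print(directions.shape)  # torch.Size([9, 2])

With the per-direction results stacked along the last dimension as above, neighbours[b, d] is then (presumably) the neighbour of box b in direction directions[d].
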
Example 3
def random_problem(S=3, T=5, D=2, device='cuda'):
    prob = arrdict.arrdict(
        sources=np.random.uniform(-1., +1., (S, D)),
        charges=np.random.uniform(.1, 1., (S, )),
        targets=np.random.uniform(
            -1., +1.,
            (T, D))).map(lambda t: torch.as_tensor(t).float().to(device))
    prob['kernel'] = quad_kernel
    return prob
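
quad_kernel is defined elsewhere in the codebase and isn't shown in these examples. Purely as an illustration of the call shape a kernel plugged in like this would need, here's a hypothetical inverse-square kernel (the name and exact form are placeholders, not the library's definition):

import torch

def example_quad_kernel(a, b):
    # a, b: (..., D) tensors of positions, broadcast against each other;
    # returns a (...,) tensor of pairwise interaction strengths.
    return 1 / ((a - b)**2).sum(-1)
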
Example 4
def interaction_scheme(tree, depths):
    D = tree.children.ndim - 1
    directions = sets.flat_cartesian_product(
        torch.tensor([-1, 0, +1], device=tree.id.device), D)
    neighbours = torch.stack(
        [neighbour_boxes(tree, tree.id, d) for d in directions], -1)

    w = w_pairs(tree, directions, neighbours)

    return arrdict.arrdict(w=ragged.from_pairs(w, len(tree.id), len(tree.id)),
                           x=ragged.from_pairs(w.flip((1, )), len(tree.id),
                                               len(tree.id)),
                           u=u_scheme(tree, neighbours),
                           v=v_scheme(tree, depths, directions, neighbours))
Example 5
def simulate(n=10e3, T=40, device='cpu'):
    print(
        f'This demo will be for {int(n)} agents and {T} timesteps on device "{device}".'
    )
    print(
        'The default values are fairly small, so as not to frustrate anyone with out-of-memory errors. Pass larger ones if you want.'
    )
    print(
        'Pass device="cuda" to run on the GPU. On Colab, you may need to enable it first under Edit->Notebook Settings.'
    )

    print('Generating a population...')
    pop = population.points(n=n)

    # Phrase it as an n-body problem and stick it on the GPU
    prob = arrdict.arrdict(
        sources=pop, targets=pop, charges=np.zeros(
            len(pop))).map(lambda t: torch.as_tensor(t).float().to(device))

    # Create our risk kernel. Usually this'd be hard coded, but we want the
    # demo to be interesting for a range of population densities, so it needs
    # to be variable.
    risk_kernel = scale_risk_kernel(n)

    # Wrap the risk kernel so it can be fed into the solver
    prob['kernel'] = wrap(risk_kernel)

    print('Presolving...')
    presoln = presolve(prob)

    # Set patient zero - the infection status is independent of the presolve!
    presoln.scaled.charges[0] = 1.

    print('Running main loop...')
    infected = [presoln.scaled.charges.cpu().numpy()]
    for t in tqdm(range(T)):
        # Evaluate the total risk over all n^2 pairs of points
        log_nonrisk = evaluate(**presoln)
        risk = 1 - torch.exp(log_nonrisk)

        # Figure out which people got infected
        rands = torch.rand_like(risk)
        presoln.scaled.charges = ((rands < risk) |
                                  (0 < presoln.scaled.charges)).float()

        # Add the last step's result to the output
        infected.append(presoln.scaled.charges.cpu().numpy())

    return infected, prob.targets.cpu().numpy()
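
The core of the loop above is the infection update: turn the accumulated log non-infection probability into a per-person risk, sample against it, and OR the result into the existing infection state. A standalone sketch of just that step on toy tensors (no pybbfmm needed):

import torch

charges = torch.tensor([1., 0., 0., 0.])            # person 0 is patient zero
log_nonrisk = torch.tensor([0., -.5, -2., -.01])    # made-up output of the evaluation step
risk = 1 - torch.exp(log_nonrisk)                   # chance of infection this timestep
rands = torch.rand_like(risk)
charges = ((rands < risk) | (0 < charges)).float()  # newly infected, plus the already-infected
print(charges)
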
Example 6
def interaction_scheme(tree, depths):
    """Returns the datastructures needed to calculate the :ref:`interactions <presolve>` between boxes.
    
    The datastructures are pretty heterogeneous because, well, performance. They're set
    up so the data needed can be accessed quickly without blowing up the memory budget.

    :param tree: a :ref:`tree <presolve>`.
    :param depths: the :ref:`depths <presolve>` to go with the tree.
    :return: a :ref:`scheme <presolve>`.
    """
    D = tree.children.ndim - 1
    directions = sets.flat_cartesian_product(
        torch.tensor([-1, 0, +1], device=tree.id.device), D)
    neighbours = torch.stack(
        [neighbour_boxes(tree, tree.id, d) for d in directions], -1)

    w = w_pairs(tree, directions, neighbours)

    return arrdict.arrdict(w=ragged.from_pairs(w, len(tree.id), len(tree.id)),
                           x=ragged.from_pairs(w.flip((1, )), len(tree.id),
                                               len(tree.id)),
                           u=u_scheme(tree, neighbours),
                           v=v_scheme(tree, depths, directions, neighbours))
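
For context, a heavily hedged sketch of how the functions in these examples might be wired together end to end. It assumes the functions above are in scope and that pybbfmm.scale (referenced in the docstrings but not shown here) takes a problem and returns the scaled version expected by orthantree; treat the exact signatures as assumptions:

import pybbfmm

prob = random_problem(S=100, T=100, D=2, device='cpu')  # Example 3
scaled = pybbfmm.scale(prob)                            # assumed signature: problem in, scaled problem out
tree, indices, depths = orthantree(scaled)              # Examples 8-10
scheme = interaction_scheme(tree, depths)               # the function above
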
Example 7
def v_w_problem():
    return arrdict.arrdict(sources=torch.tensor([[-.25, +.75], [-.75, +.75]]),
                           charges=torch.tensor([1., 1.]),
                           targets=torch.tensor([[+.25, +.75], [+.75, +.75]]),
                           scale=torch.tensor([1., 1.])).cuda()
Example 8
def orthantree(scaled, capacity=8):
    #TODO: Well this is a travesty of incomprehensibility. Verify it then explain yourself.
    D = scaled.sources.shape[1]

    points = torch.cat([scaled.sources, scaled.targets])
    indices = points.new_zeros((len(points), ), dtype=torch.long)

    tree = arrdict.arrdict(parents=indices.new_full((1, ), -1),
                           depths=indices.new_zeros((1, )),
                           centers=points.new_zeros((1, D)),
                           terminal=indices.new_ones((1, ), dtype=torch.bool),
                           children=indices.new_full((1, ) + (2, ) * D, -1),
                           descent=indices.new_zeros((1, D)))

    bases = 2**torch.flip(torch.arange(D, device=indices.device), (0, ))
    subscript_offsets = sets.cartesian_product(
        torch.tensor([0, 1], device=indices.device), D)
    center_offsets = sets.cartesian_product(
        torch.tensor([-1, +1], device=indices.device), D)

    depthcounts = [torch.as_tensor([1], device=indices.device)]

    depth = 0
    while True:
        used, used_inv = torch.unique(indices, return_inverse=True)
        source_idxs, target_idxs = indices[:len(scaled.sources)], indices[
            -len(scaled.targets):]
        tree.terminal[used] = underoccupied(source_idxs, target_idxs,
                                            tree.terminal, capacity)[used]

        used_is_active = ~tree.terminal[used]
        point_is_active = used_is_active[used_inv]
        if not point_is_active.any():
            break

        depth += 1

        active = used[used_is_active]
        active_inv = (used_is_active.cumsum(0) -
                      used_is_active.long())[used_inv[point_is_active]]
        first_child = len(tree.parents) + 2**D * torch.arange(
            len(active), device=active.device)
        point_offset = (
            (points[point_is_active] >= tree.centers[active][active_inv]) *
            bases).sum(-1)
        child_box = first_child[active_inv] + point_offset
        indices[point_is_active] = child_box

        trailing_ones = (slice(None), ) + (None, ) * D
        tree.children[active] = first_child[trailing_ones] + (
            subscript_offsets * bases).sum(-1)

        centers = tree.centers[active][trailing_ones] + center_offsets.float(
        ) / 2**depth
        descent = center_offsets[None].expand_as(centers)

        n_children = len(active) * 2**D
        children = arrdict.arrdict(parents=active.repeat_interleave(2**D),
                                   depths=tree.depths.new_full((n_children, ),
                                                               depth),
                                   centers=centers.reshape(-1, D),
                                   descent=descent.reshape(-1, D),
                                   terminal=tree.terminal.new_ones(
                                       (n_children, )),
                                   children=tree.children.new_full(
                                       (n_children, ) + (2, ) * D, -1))
        tree = arrdict.cat([tree, children])

        depthcounts.append(n_children)

    tree['id'] = torch.arange(len(tree.parents), device=points.device)

    indices = arrdict.arrdict(sources=indices[:len(scaled.sources)],
                              targets=indices[-len(scaled.targets):])

    depths = ragged.Ragged(torch.arange(len(tree.id), device=points.device),
                           torch.as_tensor(depthcounts, device=points.device))

    return tree, indices, depths
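
The line that assigns points to children above relies on a small indexing trick: comparing a point against its box's center gives one bit per dimension, and dotting those bits with bases turns them into a flat child offset in [0, 2**D). A standalone demonstration for D=2:

import torch

D = 2
bases = 2**torch.flip(torch.arange(D), (0,))        # tensor([2, 1])
center = torch.tensor([0., 0.])
points = torch.tensor([[-.5, -.5], [-.5, +.5], [+.5, -.5], [+.5, +.5]])
# One bit per dimension, weighted by bases and summed into a flat offset.
offsets = ((points >= center) * bases).sum(-1)
print(offsets)                                      # tensor([0, 1, 2, 3])
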
Example 9
def orthantree(scaled, capacity=8):
    """Constructs a :ref:`tree <presolve>` for the given :func:`~pybbfmm.scale`'d problem.

    This is a bit of a mess of a function, but long story short it starts with all the sources allocated to the root
    and repeatedly subdivides overfull boxes, constructing the various tree tensors as it goes.

    :param scaled: :func:`~pybbfmm.scale`'d problem.
    :param capacity: the max number of sources or targets per box.
    :return: A :ref:`tree <presolve>`.
    """
    D = scaled.sources.shape[1]

    points = torch.cat([scaled.sources, scaled.targets])
    indices = points.new_zeros((len(points), ), dtype=torch.long)

    tree = arrdict.arrdict(parents=indices.new_full((1, ), -1),
                           depths=indices.new_zeros((1, )),
                           centers=points.new_zeros((1, D)),
                           terminal=indices.new_ones((1, ), dtype=torch.bool),
                           children=indices.new_full((1, ) + (2, ) * D, -1),
                           descent=indices.new_zeros((1, D)))

    bases = 2**torch.flip(torch.arange(D, device=indices.device), (0, ))
    subscript_offsets = sets.cartesian_product(
        torch.tensor([0, 1], device=indices.device), D)
    center_offsets = sets.cartesian_product(
        torch.tensor([-1, +1], device=indices.device), D)

    depthcounts = [torch.as_tensor([1], device=indices.device)]

    depth = 0
    while True:
        used, used_inv = torch.unique(indices, return_inverse=True)
        source_idxs, target_idxs = indices[:len(scaled.sources)], indices[
            -len(scaled.targets):]
        tree.terminal[used] = underoccupied(source_idxs, target_idxs,
                                            tree.terminal, capacity)[used]

        used_is_active = ~tree.terminal[used]
        point_is_active = used_is_active[used_inv]
        if not point_is_active.any():
            break

        depth += 1

        active = used[used_is_active]
        active_inv = (used_is_active.cumsum(0) -
                      used_is_active.long())[used_inv[point_is_active]]
        first_child = len(tree.parents) + 2**D * torch.arange(
            len(active), device=active.device)
        point_offset = (
            (points[point_is_active] >= tree.centers[active][active_inv]) *
            bases).sum(-1)
        child_box = first_child[active_inv] + point_offset
        indices[point_is_active] = child_box

        trailing_ones = (slice(None), ) + (None, ) * D
        tree.children[active] = first_child[trailing_ones] + (
            subscript_offsets * bases).sum(-1)

        centers = tree.centers[active][trailing_ones] + center_offsets.float(
        ) / 2**depth
        descent = center_offsets[None].expand_as(centers)

        n_children = len(active) * 2**D
        children = arrdict.arrdict(parents=active.repeat_interleave(2**D),
                                   depths=tree.depths.new_full((n_children, ),
                                                               depth),
                                   centers=centers.reshape(-1, D),
                                   descent=descent.reshape(-1, D),
                                   terminal=tree.terminal.new_ones(
                                       (n_children, )),
                                   children=tree.children.new_full(
                                       (n_children, ) + (2, ) * D, -1))
        tree = arrdict.cat([tree, children])

        depthcounts.append(n_children)

    tree['id'] = torch.arange(len(tree.parents), device=points.device)

    indices = arrdict.arrdict(sources=indices[:len(scaled.sources)],
                              targets=indices[-len(scaled.targets):])

    depths = ragged.Ragged(torch.arange(len(tree.id), device=points.device),
                           torch.as_tensor(depthcounts, device=points.device))

    return tree, indices, depths
Example 10
def orthantree(scaled, capacity=8):
    """Construct a D-dimensional adaptively-refined quadtree/octtree/etc over the given problem.
    Stop subdividing when each leaf box has at most `capacity` sources points and at most 
    `capacity` target points in it.

    This is a bit of a mess of a function, but long story short it starts with all the sources
    allocated to the root and repeatedly subdivides overfull boxes. Boxes are represented
    by integer indices, with the root being index 0. This means that you'll usually find the attributes 
    of box `i` at index `i` of an array: parents[3] gives the index of the parent of box `3`,
    children[3] gives the children of box `3`, and so on.

    Anyway, out of this function you get three datastructures.

    tree: an arrdict describing the tree itself. Thinking of indexing into the arrays in this dict
        as a map of sorts,
            * `parents`: maps boxes to their parents
            * `depths`: maps boxes to their depth in the tree
            * `centers`: maps boxes to their physical center
            * `terminal`: maps boxes to a boolean saying whether that box is a leaf
            * `children`: maps boxes to a (2,)/(2, 2)/(2, 2, 2)/etc array of 2**D children
            * `descent`: maps boxes to a (D,)-vector of what kind of child that box is, with elements from (-1, +1).
    indices: an arrdict mapping sources and targets to the leaf box they lie in.
    depths: a ragged array mapping each depth to the boxes at that depth.
    
    """
    D = scaled.sources.shape[1]

    points = torch.cat([scaled.sources, scaled.targets])
    indices = points.new_zeros((len(points), ), dtype=torch.long)

    tree = arrdict.arrdict(parents=indices.new_full((1, ), -1),
                           depths=indices.new_zeros((1, )),
                           centers=points.new_zeros((1, D)),
                           terminal=indices.new_ones((1, ), dtype=torch.bool),
                           children=indices.new_full((1, ) + (2, ) * D, -1),
                           descent=indices.new_zeros((1, D)))

    bases = 2**torch.flip(torch.arange(D, device=indices.device), (0, ))
    subscript_offsets = sets.cartesian_product(
        torch.tensor([0, 1], device=indices.device), D)
    center_offsets = sets.cartesian_product(
        torch.tensor([-1, +1], device=indices.device), D)

    depthcounts = [torch.as_tensor([1], device=indices.device)]

    depth = 0
    while True:
        used, used_inv = torch.unique(indices, return_inverse=True)
        source_idxs, target_idxs = indices[:len(scaled.sources)], indices[
            -len(scaled.targets):]
        tree.terminal[used] = underoccupied(source_idxs, target_idxs,
                                            tree.terminal, capacity)[used]

        used_is_active = ~tree.terminal[used]
        point_is_active = used_is_active[used_inv]
        if not point_is_active.any():
            break

        depth += 1

        active = used[used_is_active]
        active_inv = (used_is_active.cumsum(0) -
                      used_is_active.long())[used_inv[point_is_active]]
        first_child = len(tree.parents) + 2**D * torch.arange(
            len(active), device=active.device)
        point_offset = (
            (points[point_is_active] >= tree.centers[active][active_inv]) *
            bases).sum(-1)
        child_box = first_child[active_inv] + point_offset
        indices[point_is_active] = child_box

        trailing_ones = (slice(None), ) + (None, ) * D
        tree.children[active] = first_child[trailing_ones] + (
            subscript_offsets * bases).sum(-1)

        centers = tree.centers[active][trailing_ones] + center_offsets.float(
        ) / 2**depth
        descent = center_offsets[None].expand_as(centers)

        n_children = len(active) * 2**D
        children = arrdict.arrdict(parents=active.repeat_interleave(2**D),
                                   depths=tree.depths.new_full((n_children, ),
                                                               depth),
                                   centers=centers.reshape(-1, D),
                                   descent=descent.reshape(-1, D),
                                   terminal=tree.terminal.new_ones(
                                       (n_children, )),
                                   children=tree.children.new_full(
                                       (n_children, ) + (2, ) * D, -1))
        tree = arrdict.cat([tree, children])

        depthcounts.append(n_children)

    tree['id'] = torch.arange(len(tree.parents), device=points.device)

    indices = arrdict.arrdict(sources=indices[:len(scaled.sources)],
                              targets=indices[-len(scaled.targets):])

    depths = ragged.Ragged(torch.arange(len(tree.id), device=points.device),
                           torch.as_tensor(depthcounts, device=points.device))

    return tree, indices, depths
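
The child-center update above (center_offsets.float() / 2**depth) encodes the subdivision geometry: assuming the scaled root box spans [-1, 1] in each dimension, as the /2**depth factor suggests, a child at depth d sits at its parent's center offset by (±1, ..., ±1) / 2**d, which is half the child box's side length. A standalone sketch for D=2, using itertools in place of the library's sets helper:

import itertools

import torch

D = 2
center_offsets = torch.tensor(list(itertools.product([-1, +1], repeat=D)))
parent_center = torch.tensor([-.5, -.5])   # a depth-1 box center in a [-1, 1]^2 root box
depth = 2
child_centers = parent_center + center_offsets.float() / 2**depth
print(child_centers)
# tensor([[-0.7500, -0.7500],
#         [-0.7500, -0.2500],
#         [-0.2500, -0.7500],
#         [-0.2500, -0.2500]])
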