Beispiel #1
0
def dot_general_dependency_rule(outstart, outcount, lhs, rhs,
                                dimension_numbers, precision):
    """Dependency rule for ``lax.dot_general``.

    Maps a requested output box (``outstart`` plus ``outcount.shape``) to the
    boxes of ``lhs`` and ``rhs`` it reads, plus per-element counts for each
    operand.

    Args:
      outstart: start index of the output box, one entry per output dim.
      outcount: counts for the output box; only all-ones is supported.
      lhs, rhs: the two operands.
      dimension_numbers: ``((lhs_contracting, rhs_contracting),
        (lhs_batch, rhs_batch))``, as for ``lax.dot_general``.
      precision: forwarded to ``lax.dot_general``.

    Returns:
      ``([lhs_box, rhs_box], incounts, outslice)``. An operand that is not a
      ``LazyArray`` gets ``None`` in ``incounts``. ``outslice`` applies
      ``lax.dot_general`` to the operand slices.

    Raises:
      NotImplementedError: if ``outcount`` is not all ones.
    """
    if not is_ones(outcount):
        raise NotImplementedError
    outshape = outcount.shape
    # (start, size) of the output box along every output dimension.
    outslices = list(zip(outstart, outshape))
    contracting, batch = dimension_numbers
    lhs_contracting, rhs_contracting = contracting
    lhs_batch, rhs_batch = batch
    # Output dims are indexed as [batch..., lhs free..., rhs free...].
    lhs_free = list(range(len(lhs_batch),
                          len(lhs.shape) - len(lhs_contracting)))
    rhs_free = list(range(len(rhs_batch) + len(lhs_free), len(outshape)))

    def operand_box(operand, batch_dims, free_dims, contracting_dims):
        # Restrict the output box to this operand's batch + free dims, then
        # let the reduce rule re-insert the contracted axes.
        # NOTE(review): this calls ``reduce_dependency_rule(None)`` as a rule
        # factory; verify that matches the reduce rule in scope. It also
        # assumes the operand's batch dims precede its free dims in its own
        # axis order -- confirm for non-leading batch dims.
        start, shape = unzip2(
            [outslices[d] for d in list(batch_dims) + free_dims])
        (box, ), (count, ), _ = reduce_dependency_rule(None)(
            start, Ones(shape), operand, axes=contracting_dims)
        return box, count

    def operand_count(operand, count, other_free_dims):
        # Operands that are not LazyArrays get no count.
        if not isinstance(operand, LazyArray):
            return None
        # Each element in the box is read once per output position along the
        # other operand's free dims.
        return materialize(count) * prod(np.take(outshape, other_free_dims))

    lhs_box, lhs_count = operand_box(lhs, lhs_batch, lhs_free,
                                     lhs_contracting)
    rhs_box, rhs_count = operand_box(rhs, rhs_batch, rhs_free,
                                     rhs_contracting)
    incounts = [
        operand_count(lhs, lhs_count, rhs_free),
        operand_count(rhs, rhs_count, lhs_free),
    ]

    def outslice(*inslices):
        return lax.dot_general(*inslices, dimension_numbers, precision)

    return [lhs_box, rhs_box], incounts, outslice
Beispiel #2
0
def pad_dependency_rule(outstart, outcount, operand, padding_value,
                        padding_config):
    """Dependency rule for ``lax.pad``.

    Maps a requested box of the padded output (``outstart`` plus
    ``outcount.shape``) to the box of ``operand`` it reads, and to the number
    of output positions filled from the padding value.

    Args:
      outstart: start index of the output box, one entry per dimension.
      outcount: counts for the output box, or a ``Ones`` marker.
      operand: the array being padded.
      padding_value: the padding scalar. It is not read here; the value used
        for padding is passed to the returned ``outslice``.
      padding_config: ``(lo, hi, interior)`` triples, one per dimension.

    Returns:
      ``([inbox, ([], [])], [incount, padcount], outslice)``.
      ``inbox`` is ``(instart, inshape)`` for the operand elements inside the
      output box, and ``incount`` holds their counts. Both are ``None`` when
      no operand element falls in the box. The padding value gets an empty
      box and count ``padcount``. ``outslice(inslice, padding_value)``
      rebuilds the output box from the operand slice.
    """
    lo, _, interior = unzip3(padding_config)
    dilation = np.array(interior) + 1
    # Output box start, measured from the first operand element.
    outstart_lo = np.subtract(outstart, lo)

    def ceil_divide(x, y):
        # Same formula as jax's private ``lax.lax._ceil_divide``; kept local
        # so this rule does not depend on a private API.
        return -np.floor_divide(np.negative(x), y)

    def inclip(indices):
        return np.clip(indices, 0, operand.shape)

    # Operand element i sits i * dilation past the first operand element, so
    # the box covers operand indices [instart, instop).
    instart = inclip(ceil_divide(outstart_lo, dilation))
    instop = inclip(ceil_divide(outstart_lo + outcount.shape, dilation))
    inshape = instop - instart
    insize = prod(inshape)
    # Positions, relative to the output box, of the first included operand
    # element (offset) and one past the last one (limit).
    offset = instart * dilation - outstart_lo
    limit = offset + np.maximum(0, (np.array(inshape) - 1) * dilation + 1)
    if not insize:
        # No operand element in the box. Report no count (matching the
        # ``None`` box), also for ``Ones`` outcounts.
        incount = None
    elif is_ones(outcount):
        incount = Ones(inshape)
    else:
        incount = laxref.slice(outcount, offset, limit, dilation)
    # NOTE(review): this counts padded *positions*. If ``outcount`` holds
    # counts other than one, the padding value's count may need to be a sum
    # over those positions instead -- confirm the intended semantics.
    padcount = outcount.size - insize

    def outslice(inslice, padding_value):
        assert inslice is None or np.array_equal(inslice.shape, inshape)
        if not insize:
            return jnp.full(outcount.shape, padding_value, operand.dtype)
        hi = np.array(outcount.shape) - limit
        return lax.pad(inslice, padding_value,
                       tuple(zip(offset, hi, interior)))

    return ([(instart, inshape) if insize else None, ([], [])],
            [incount, padcount], outslice)
Beispiel #3
0
def transpose_dependency_rule(outstart, outcount, operand, permutation):
    """Dependency rule for ``lax.transpose``.

    The operand box is the output box with its axes permuted by the inverse
    of ``permutation``. Non-``Ones`` counts are transposed the same way.

    Returns:
      ``([(instart, inshape)], [incount], outslice)``, where ``outslice``
      re-applies ``lax.transpose`` with ``permutation``.
    """
    inverse = np.argsort(permutation)
    inshape = np.take(outcount.shape, inverse)
    instart = np.take(outstart, inverse)
    if is_ones(outcount):
        incount = Ones(inshape)
    else:
        incount = np.transpose(outcount, inverse)

    def outslice(inslice):
        return lax.transpose(inslice, permutation)

    return [(instart, inshape)], [incount], outslice
Beispiel #4
0
def concatenate_dependency_rule(outstart, outcount, *operands, dimension):
    """Dependency rule for ``lax.concatenate``.

    For each operand, finds the part of the requested output box (``outstart``
    plus ``outcount.shape``) that the operand supplies along ``dimension``.

    Args:
      outstart: start index of the output box, one entry per dimension.
      outcount: counts for the output box; only all-ones is supported.
      *operands: the concatenated arrays, in order.
      dimension: the concatenation axis.

    Returns:
      ``(inboxes, incounts, outslice)``.
      For each operand that overlaps the output box, ``inboxes`` holds a
      ``(start, shape)`` box and ``incounts`` a matching ``Ones``. An operand
      that does not overlap gets ``None`` in both lists. ``outslice``
      concatenates the non-``None`` slices along ``dimension``.

    Raises:
      NotImplementedError: if ``outcount`` is not all ones.
    """
    if not is_ones(outcount):
        raise NotImplementedError
    dim = dimension
    outstart, outshape = list(outstart), list(outcount.shape)
    box_lo = outstart[dim]
    box_hi = box_lo + outshape[dim]
    position = 0  # start of the current operand along ``dim`` in the output
    inboxes = []
    incounts = []
    for operand in operands:
        size = operand.shape[dim]
        if box_lo < position + size and position < box_hi:
            # Overlap of [box_lo, box_hi) with [position, position + size),
            # in output coordinates. The length is measured from the clipped
            # start ``lo``, so the box never runs past the operand's end.
            lo = max(box_lo, position)
            hi = min(box_hi, position + size)
            instart = outstart[:dim] + [lo - position] + outstart[dim + 1:]
            inshape = outshape[:dim] + [hi - lo] + outshape[dim + 1:]
            inboxes.append((instart, inshape))
            incounts.append(Ones(inshape))
        else:
            inboxes.append(None)
            incounts.append(None)
        position += size

    return inboxes, incounts, lambda *inslices: lax.concatenate(
        [x for x in inslices if x is not None], dimension)
Beispiel #5
0
def reduce_dependency_rule(prim, outstart, outcount, operand, axes, **kwargs):
    """Dependency rule for a reduction primitive ``prim`` over ``axes``.

    The operand box equals the output box, with every reduced axis
    re-inserted at its operand position, starting at 0 and spanning the full
    operand extent.

    Returns:
      ``([(instart, inshape)], [Ones(inshape)], outslice)``, where
      ``outslice`` binds ``prim`` on the operand slice with the same ``axes``
      and keyword parameters.

    Raises:
      NotImplementedError: if ``outcount`` is not all ones.
    """
    if not is_ones(outcount):
        raise NotImplementedError
    instart, inshape = list(outstart), list(outcount.shape)
    # Insert in ascending axis order so each index refers to the operand's
    # axis numbering at the moment it is inserted.
    for axis in np.sort(axes):
        instart[axis:axis] = [0]
        inshape[axis:axis] = [operand.shape[axis]]

    def outslice(inslice):
        return prim.bind(inslice, axes=axes, **kwargs)

    return [(instart, inshape)], [Ones(inshape)], outslice
Beispiel #6
0
def squeeze_dependency_rule(outstart, outcount, operand, dimensions):
    """Dependency rule for ``lax.squeeze``.

    The operand box is the output box with a size-1 axis (starting at 0)
    re-inserted at each squeezed position in ``dimensions``.

    Returns:
      ``([(instart, inshape)], [Ones(inshape)], outslice)``, where
      ``outslice`` re-applies ``lax.squeeze`` with ``dimensions``.

    Raises:
      NotImplementedError: if ``outcount`` is not all ones.
    """
    if not is_ones(outcount):
        raise NotImplementedError
    instart, inshape = list(outstart), list(outcount.shape)
    # Ascending order keeps each insertion index valid for the operand.
    for axis in np.sort(dimensions):
        instart[axis:axis] = [0]
        inshape[axis:axis] = [1]

    def outslice(inslice):
        return lax.squeeze(inslice, dimensions)

    return [(instart, inshape)], [Ones(inshape)], outslice
Beispiel #7
0
def rev_dependency_rule(outstart, outcount, operand, dimensions):
    """Dependency rule for ``lax.rev``.

    Along each axis in ``dimensions``, the output box is mirrored within the
    operand. Other axes keep their start. The box shape is unchanged, and
    non-``Ones`` counts are reversed along the same axes.

    Returns:
      ``([(instart, outcount.shape)], [incount], outslice)``, where
      ``outslice`` re-applies ``lax.rev`` with ``dimensions``.
    """
    per_axis = zip(operand.shape, outcount.shape, outstart)
    instart = [size - (start + outsize) if axis in dimensions else start
               for axis, (size, outsize, start) in enumerate(per_axis)]
    incount = (Ones(outcount.shape) if is_ones(outcount) else
               lax.rev(outcount, dimensions))

    def outslice(inslice):
        return lax.rev(inslice, dimensions)

    return [(instart, outcount.shape)], [incount], outslice
Beispiel #8
0
def test_conv_incounts_strided():
    """Check ``conv_incounts`` for a width-7 kernel with stride 3.

    Each expected lhs count is the number of stride-3, width-7 windows over
    21 elements that cover that position.
    """
    kernel_shape = (1, 1, 7)
    expected_lhs = [[[
        1, 1, 1, 2, 2, 2, 3, 2, 2, 3, 2, 2, 3, 2, 2, 2, 1, 1, 1, 0, 0
    ]]]
    lhs_count, rhs_count = conv_incounts((1, 1, 21), kernel_shape, (3, ))
    np.testing.assert_array_equal(expected_lhs, lhs_count)
    np.testing.assert_array_equal(np.ones(kernel_shape), rhs_count)


@pytest.mark.parametrize(
    'outstart,outcount,expected_incount',
    [
        # outstart aligned with an operand element
        ((7, ), np.arange(7), [0, 3]),
        # single-position box on an operand element
        ((7, ), np.arange(1), [0]),
        # outcount given as Ones
        ((7, ), Ones((1, )), Ones((1, ))),
        # outstart inside interior padding
        ((6, ), np.arange(8), [1, 4]),
        # outstart and last position both inside interior padding
        ((6, ), np.arange(3), [1]),
        # box entirely in lo padding, no operand elements
        ((0, ), np.arange(4), None),
        # box starts in lo padding and reaches an operand element
        ((1, ), np.arange(4), [3]),
        # outstart past the last operand element
        ((15, ), np.arange(1), None),
    ])
def test_pad_dependency_rule(outstart, outcount, expected_incount):
    """Check the pad rule's operand counts and output shape.

    Padding ``np.arange(3)`` with config ``(4, 4, 2)`` places the operand
    elements at positions 4, 7 and 10 of a length-15 output.
    """
    (_operand_box, _pad_box), (incount, _pad_count), outslice = (
        pad_dependency_rule(outstart, outcount, np.arange(3), 0,
                            ((4, 4, 2), )))
    np.testing.assert_array_equal(expected_incount, incount)
    if incount is None:
        inslice = None
    else:
        inslice = np.ones(incount.shape, int)
    assert outslice(inslice, 0).shape == outcount.shape