def _dot_general_operand_box(outslices, shape, contracting, batch,
                             free_out_dims):
  """Return the ``(start, shape)`` box of one ``dot_general`` operand.

  Args:
    outslices: list of ``(start, size)`` pairs, one per output dimension.
    shape: the operand's shape.
    contracting: the operand's contracting dimensions.
    batch: the operand's batch dimensions; ``batch[i]`` pairs with output
      dimension ``i``.
    free_out_dims: the output dimensions holding this operand's remaining
      (non-batch, non-contracting) dimensions, in increasing operand order.

  Returns:
    ``(instart, inshape)`` lists with one entry per operand dimension.
  """
  contracting = list(contracting)
  batch = list(batch)
  free_out = iter(free_out_dims)
  instart, inshape = [], []
  for d, size in enumerate(shape):
    if d in contracting:
      # Contracted dimensions are needed in full.
      start, extent = 0, size
    elif d in batch:
      start, extent = outslices[batch.index(d)]
    else:
      start, extent = outslices[next(free_out)]
    instart.append(start)
    inshape.append(extent)
  return instart, inshape


def dot_general_dependency_rule(outstart, outcount, lhs, rhs,
                                dimension_numbers, precision):
  """Dependency rule for ``lax.dot_general``.

  Args:
    outstart: start index of the requested output box.
    outcount: count array for the output box; only all-ones is supported.
    lhs, rhs: the two operands (anything with a ``.shape``).
    dimension_numbers: ``((lhs_contracting, rhs_contracting),
      (lhs_batch, rhs_batch))`` as for ``lax.dot_general``.
    precision: forwarded to ``lax.dot_general``.

  Returns:
    ``([lhs_box, rhs_box], incounts, outslice)``. Each box is a
    ``(start, shape)`` pair in the operand's own dimension order. The
    ``incounts`` entry for an operand is ``None`` when that operand is not a
    ``LazyArray``. ``outslice(lhs_slice, rhs_slice)`` computes the output box.

  Raises:
    NotImplementedError: if ``outcount`` is not all ones.
  """
  if not is_ones(outcount):
    raise NotImplementedError
  outshape = outcount.shape
  outslices = list(zip(outstart, outshape))
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
  # dot_general output layout: (batch..., lhs free..., rhs free...).
  lhs_other_out_dims = list(
      range(len(lhs_batch), len(lhs.shape) - len(lhs_contracting)))
  rhs_other_out_dims = list(
      range(len(rhs_batch) + len(lhs_other_out_dims), len(outshape)))
  # Boxes are built per operand dimension. Batch dims are looked up by their
  # position in lhs_batch/rhs_batch, not by the operand dimension number, so
  # non-leading batch dimensions are handled correctly.
  lhs_box = _dot_general_operand_box(
      outslices, lhs.shape, lhs_contracting, lhs_batch, lhs_other_out_dims)
  rhs_box = _dot_general_operand_box(
      outslices, rhs.shape, rhs_contracting, rhs_batch, rhs_other_out_dims)
  # The lhs count is scaled by the number of output positions along the rhs
  # free dimensions, and vice versa.
  incounts = [
      materialize(Ones(lhs_box[1])) * prod(np.take(outshape, rhs_other_out_dims))
      if isinstance(lhs, LazyArray) else None,
      materialize(Ones(rhs_box[1])) * prod(np.take(outshape, lhs_other_out_dims))
      if isinstance(rhs, LazyArray) else None
  ]
  return ([lhs_box, rhs_box], incounts,
          lambda *inslices: lax.dot_general(
              *inslices, dimension_numbers, precision))
def pad_dependency_rule(outstart, outcount, operand, padding_value,
                        padding_config):
  """Dependency rule for ``lax.pad``.

  Maps the output box ``[outstart, outstart + outcount.shape)`` of the padded
  array back to the operand elements inside it, using only the low and
  interior padding from ``padding_config``.

  Returns:
    ``([operand_box, ([], [])], [incount, padcount], outslice)``.
    ``operand_box`` is ``(instart, inshape)``, or ``None`` when the output box
    contains no operand elements. ``([], [])`` is the (scalar) box for
    ``padding_value``. ``outslice(inslice, padding_value)`` rebuilds the
    output box from the operand slice.
  """
  lo_pad, _, interior_pad = unzip3(padding_config)
  stride = np.array(interior_pad) + 1
  rel_start = np.subtract(outstart, lo_pad)

  def clip_to_operand(idx):
    return np.clip(idx, 0, operand.shape)

  # Operand index i sits at relative output position i * stride, so the
  # operand elements in the window are [ceil(lo / stride), ceil(hi / stride)).
  in_lo = clip_to_operand(lax.lax._ceil_divide(rel_start, stride))
  in_hi = clip_to_operand(
      lax.lax._ceil_divide(rel_start + outcount.shape, stride))
  inshape = in_hi - in_lo
  num_in = prod(inshape)
  # Position (within the output box) of the first and one-past-last operand
  # element, so the operand slice can be re-padded into the box.
  first = in_lo * stride - rel_start
  last = first + np.maximum(0, (np.array(inshape) - 1) * stride + 1)

  if is_ones(outcount):
    incount = Ones(inshape)
  elif num_in:
    incount = laxref.slice(outcount, first, last, stride)
  else:
    incount = None
  padcount = outcount.size - num_in

  def outslice(inslice, padding_value):
    assert inslice is None or np.array_equal(inslice.shape, inshape)
    if not num_in:
      return jnp.full(outcount.shape, padding_value, operand.dtype)
    hi_pad = np.array(outcount.shape) - last
    return lax.pad(inslice, padding_value, zip(first, hi_pad, interior_pad))

  operand_box = (in_lo, inshape) if num_in else None
  return ([operand_box, ([], [])], [incount, padcount], outslice)
def transpose_dependency_rule(outstart, outcount, operand, permutation):
  """Dependency rule for ``lax.transpose``.

  The operand box is the output box with its axes permuted by the inverse of
  ``permutation``. The counts are permuted the same way.
  """
  inv = np.argsort(permutation)
  inshape = np.take(outcount.shape, inv)
  instart = np.take(outstart, inv)
  if is_ones(outcount):
    incount = Ones(inshape)
  else:
    incount = np.transpose(outcount, inv)

  def outslice(inslice):
    return lax.transpose(inslice, permutation)

  return [(instart, inshape)], [incount], outslice
def concatenate_dependency_rule(outstart, outcount, *operands, dimension):
  """Dependency rule for ``lax.concatenate``.

  For each operand, compute the part of it that lies inside the output box
  along ``dimension``. Operands that do not overlap the box get ``None`` for
  both box and count.

  Returns:
    ``(inboxes, incounts, outslice)``, with one box and count per operand.
    ``outslice(*inslices)`` concatenates the non-``None`` slices.

  Raises:
    NotImplementedError: if ``outcount`` is not all ones.
  """
  if not is_ones(outcount):
    raise NotImplementedError
  dim = dimension
  outstart, outshape = list(outstart), list(outcount.shape)
  dimstart, dimshape = outstart[dim], outshape[dim]
  dimstop = dimstart + dimshape
  position = 0  # offset of the current operand along `dim` in the output
  inboxes = []
  incounts = []
  for operand in operands:
    size = operand.shape[dim]
    if dimstart < position + size and position < dimstop:
      # Intersection of the output window [dimstart, dimstop) with this
      # operand's span [position, position + size), in output coordinates.
      lo = max(dimstart, position)
      hi = min(dimstop, position + size)
      instart = outstart[:dim] + [lo - position] + outstart[dim + 1:]
      inshape = outshape[:dim] + [hi - lo] + outshape[dim + 1:]
      inboxes.append((instart, inshape))
      incounts.append(Ones(inshape))
    else:
      inboxes.append(None)
      incounts.append(None)
    position += size
  return inboxes, incounts, lambda *inslices: lax.concatenate(
      [x for x in inslices if x is not None], dimension)
def reduce_dependency_rule(prim, outstart, outcount, operand, axes, **kwargs):
  """Dependency rule for reductions bound through ``prim``.

  The operand box is the output box with each reduced axis re-inserted, and
  that axis spans the operand's full extent.

  Raises:
    NotImplementedError: if ``outcount`` is not all ones.
  """
  if not is_ones(outcount):
    raise NotImplementedError
  start, shape = list(outstart), list(outcount.shape)
  # Inserting in ascending axis order leaves each axis at its final index.
  for axis in sorted(axes):
    start[axis:axis] = [0]
    shape[axis:axis] = [operand.shape[axis]]

  def outslice(inslice):
    return prim.bind(inslice, axes=axes, **kwargs)

  return [(start, shape)], [Ones(shape)], outslice
def squeeze_dependency_rule(outstart, outcount, operand, dimensions):
  """Dependency rule for ``lax.squeeze``.

  The operand box is the output box with a size-1 axis starting at 0
  re-inserted at each squeezed dimension.

  Raises:
    NotImplementedError: if ``outcount`` is not all ones.
  """
  if not is_ones(outcount):
    raise NotImplementedError
  start, shape = list(outstart), list(outcount.shape)
  # Inserting in ascending order leaves each axis at its final index.
  for axis in sorted(dimensions):
    start[axis:axis] = [0]
    shape[axis:axis] = [1]

  def outslice(inslice):
    return lax.squeeze(inslice, dimensions)

  return [(start, shape)], [Ones(shape)], outslice
def rev_dependency_rule(outstart, outcount, operand, dimensions):
  """Dependency rule for ``lax.rev``.

  The operand box has the same shape as the output box. Along each reversed
  axis its start is mirrored, and the counts are reversed to match.
  """
  instart = []
  for axis, (size, extent, start) in enumerate(
      zip(operand.shape, outcount.shape, outstart)):
    # Output window [start, start + extent) on a reversed axis reads the
    # operand window that ends at size - start.
    instart.append(size - (start + extent) if axis in dimensions else start)
  if is_ones(outcount):
    incount = Ones(outcount.shape)
  else:
    incount = lax.rev(outcount, dimensions)

  def outslice(inslice):
    return lax.rev(inslice, dimensions)

  return [(instart, outcount.shape)], [incount], outslice
def test_conv_incounts_strided():
  """conv_incounts for a length-21 lhs and width-7 rhs, strided (3, )."""
  kernel_shape = (1, 1, 7)
  lhs_count, rhs_count = conv_incounts((1, 1, 21), kernel_shape, (3, ))
  expected_lhs_count = [
      [[1, 1, 1, 2, 2, 2, 3, 2, 2, 3, 2, 2, 3, 2, 2, 2, 1, 1, 1, 0, 0]]]
  np.testing.assert_array_equal(expected_lhs_count, lhs_count)
  np.testing.assert_array_equal(np.ones(kernel_shape), rhs_count)


@pytest.mark.parametrize(
    'outstart,outcount,expected_incount',
    [
        ((7, ), np.arange(7), [0, 3]),  # outstart aligned with padded element
        ((7, ), np.arange(1), [0]),  # outstart + outstop aligned with padded element
        ((7, ), Ones((1, )), Ones((1, ))),  # outcount Ones
        ((6, ), np.arange(8), [1, 4]),  # outstart in interior padding
        ((6, ), np.arange(3), [1]),  # outstart, outstop in interior padding
        ((0, ), np.arange(4), None),  # outstart in lo, no elements
        ((1, ), np.arange(4), [3]),  # outstart in lo, including padded element
        ((15, ), np.arange(1), None),  # outstart in hi
    ])
def test_pad_dependency_rule(outstart, outcount, expected_incount):
  """Check pad_dependency_rule counts and outslice shape for [0, 1, 2]
  padded with lo=4, hi=4, interior=2."""
  operand = np.arange(3)
  padding_config = ((4, 4, 2), )
  (_operand_box, _pad_box), (incount, _padcount), outslice = (
      pad_dependency_rule(outstart, outcount, operand, 0, padding_config))
  np.testing.assert_array_equal(expected_incount, incount)
  if incount is None:
    inslice = None
  else:
    inslice = np.ones(incount.shape, int)
  assert outslice(inslice, 0).shape == outcount.shape