def test_tensordot__03():
    """Check tile contraction when the left tile is built with ``None`` data.

    The expected shape comes from ``np.tensordot`` on dense random arrays
    of the same shapes; the tile result must report that shape and must
    report itself as empty.
    """
    left_shape = (2, 3, 4, 5, 6)
    right_shape = (6, 4, 5, 3, 3)
    contracted = ([1, 2, 3, 4], [3, 1, 2, 0])

    # Dense reference: only its shape is compared against the tile result.
    dense_left = np.random.rand(*left_shape)
    dense_right = np.random.rand(*right_shape)
    expected = np.tensordot(dense_left, dense_right, axes=contracted)

    # The left tile carries no array data; the right tile wraps dense_right.
    left_tile = tl.Tile(left_shape, None)
    right_tile = tl.Tile(right_shape, dense_right)
    result = tl.tensordot(left_tile, right_tile, axis_keys=contracted)

    assert expected.shape == result.get_shape()
    assert result.is_empty()
def test_tensordot__01():
    """Check tile contraction of two populated tiles against ``np.tensordot``.

    Both tiles wrap dense random arrays.  The tile result must have the
    same shape as the NumPy reference and must match it element by
    element (exact equality, as in the reference comparison below).
    """
    left_shape = (2, 3, 4, 5, 6)
    right_shape = (6, 4, 5, 3, 3)
    contracted = ([1, 2, 3, 4], [3, 1, 2, 0])

    dense_left = np.random.rand(*left_shape)
    dense_right = np.random.rand(*right_shape)
    expected = np.tensordot(dense_left, dense_right, axes=contracted)

    left_tile = tl.Tile(left_shape, dense_left)
    right_tile = tl.Tile(right_shape, dense_right)
    result = tl.tensordot(left_tile, right_tile, axis_keys=contracted)

    assert expected.shape == result.get_shape()

    # Visit every multi-index of the reference array and compare entries.
    for index in it.product(*map(range, expected.shape)):
        assert result[index] == expected[index]
def tensordot(L, R, axis_keys=([0],[0])):
    """Contract two tiled tensors over paired axes and return the result.

    Both operands are transposed so that their contracted axes come first.
    The remaining axes of ``L`` (rows) followed by the remaining axes of
    ``R`` (columns) define the axes of the returned ``TiledTensor``.  Each
    result tile accumulates the tile-level contractions over every index
    of the contracted axes.

    Parameters
    ----------
    L, R : tiled tensors
        Operands; must provide ``transpose``, ``iter_axis_keys``,
        ``get_axis`` and ``get_tile``.
    axis_keys : pair of sequences
        ``(keys_of_L, keys_of_R)`` naming the axes to contract, paired by
        position.  Any sequence type is accepted (lists, tuples, ...).

    Returns
    -------
    TiledTensor
        Tensor built over ``row_axes + col_axes``.

    Raises
    ------
    ValueError
        If the two key sequences differ in length.
    """
    sum_axis_keys_L, sum_axis_keys_R = axis_keys
    # Work on list copies: callers may pass tuples (which cannot be
    # concatenated with a list), and the default argument is never mutated.
    sum_axis_keys_L = list(sum_axis_keys_L)
    sum_axis_keys_R = list(sum_axis_keys_R)
    if len(sum_axis_keys_L) != len(sum_axis_keys_R):
        raise ValueError(
            "tensordot: axis_keys must pair axes one-to-one, got %d and %d keys"
            % (len(sum_axis_keys_L), len(sum_axis_keys_R))
        )
    ncontracted = len(sum_axis_keys_L)

    # Bring the contracted axes to the front of both operands.
    L = L.transpose(
        sum_axis_keys_L
        + [axis_key for axis_key in L.iter_axis_keys()
           if axis_key not in sum_axis_keys_L]
    )
    R = R.transpose(
        sum_axis_keys_R
        + [axis_key for axis_key in R.iter_axis_keys()
           if axis_key not in sum_axis_keys_R]
    )

    sum_axes = L.get_axis(slice(None, ncontracted))
    row_axes = L.get_axis(slice(ncontracted, None))
    col_axes = R.get_axis(slice(ncontracted, None))

    def contract_tiles(tile1, tile2):
        # After the transposes the contracted axes are the leading ones of
        # each tile.  Fresh lists are built per call (a bare ``range`` is
        # not a list on Python 3).
        leading = list(range(ncontracted))
        return tl.tensordot(tile1, tile2, axis_keys=(leading, list(leading)))

    T = TiledTensor(row_axes + col_axes)
    for r in multi_axis_iter(row_axes):
        for c in multi_axis_iter(col_axes):
            for s in multi_axis_iter(sum_axes):
                T._tiles[r+c] += contract_tiles(L.get_tile(s+r), R.get_tile(s+c))
    return T
# NOTE(review): the six statements below look like the tail of a test
# function whose ``def`` line is not visible in this part of the file.
# They are kept verbatim; confirm the enclosing definition and indentation.
tarray = np.tensordot(larray, rarray, axes=([1,2,3,4],[3,1,2,0]))
ltile = tl.Tile((2,3,4,5,6), None  )
rtile = tl.Tile((6,4,5,3,3), rarray)
ttile = tl.tensordot(ltile, rtile, axis_keys=([1,2,3,4],[3,1,2,0]))
assert tarray.shape == ttile.get_shape()
assert ttile.is_empty()


if __name__ == "__main__":
    # Ad-hoc demo run only when the file is executed as a script.
    # Every print takes a single argument in call form, so the output is
    # the same on Python 2 and the code also parses on Python 3.

    # Add a tile holding ones to a tile constructed without data.
    L = tl.Tile((5, 7), np.ones((5, 7)))
    R = tl.Tile((5, 7))
    T = L + R
    print(T)

    # Show a rank-3 tile and its transpose.
    T = tl.Tile((2,3,4), np.arange(2*3*4).reshape((2,3,4)))
    print(T.get_shape())
    print(T)
    print(T.transpose().get_shape())
    print(T.transpose())

    # Contract two random arrays with NumPy, printing the extents of the
    # paired axes first.
    larray = np.random.rand(2,3,4,5,6)
    rarray = np.random.rand(6,4,5,3,3)
    print([larray.shape[ax] for ax in [1,2,3,4]])
    print([rarray.shape[ax] for ax in [3,1,2,0]])
    tarray = np.tensordot(larray, rarray, axes=([1,2,3,4],[3,1,2,0]))
    print(tarray)

    # Same contraction through the tile API.
    L = tl.Tile((2,3,4,5,6), larray)
    R = tl.Tile((6,4,5,3,3), rarray)
    T = tl.tensordot(L, R, axis_keys=([1,2,3,4],[3,1,2,0]))
    print(T._array)