def staggered_elements(resolution: Shape, bounds: Box, extrapolation: math.Extrapolation):
    """
    Build the staggered cell geometry for every dimension of `resolution`.

    A `GridCell` is created from `resolution` and `bounds`, then staggered once
    per dimension name. The per-dimension results are stacked along a channel
    dimension called `staggered_direction`.

    Args:
        resolution: Shape whose `names` determine the stagger directions.
        bounds: Box passed to `GridCell` together with `resolution`.
        extrapolation: Queried via `valid_outer_faces(dim)` for each dimension;
            the returned pair is forwarded to `GridCell.stagger`.
            (presumably flags whether the lower / upper boundary faces are
            included; confirm against `Extrapolation.valid_outer_faces`)

    Returns:
        The result of `geom.stack` over the per-dimension staggered cells.
    """
    base_cells = GridCell(resolution, bounds)
    axis_names = resolution.names

    def stagger_along(axis):
        # Exactly two values are expected here: one for each side of `axis`.
        lower_face, upper_face = extrapolation.valid_outer_faces(axis)
        return base_cells.stagger(axis, lower_face, upper_face)

    per_axis_cells = [stagger_along(axis) for axis in axis_names]
    return geom.stack(per_axis_cells, channel(staggered_direction=axis_names))
def expand_staggered(values: Tensor, resolution: Shape, extrapolation: math.Extrapolation):
    """
    Add missing spatial dimensions to `values`.

    `values` is unstacked along its `vector` dimension into one component per
    spatial dimension of `resolution`. Each component is expanded to the
    resolution of the correspondingly staggered cells, and the expanded
    components are stacked back along a `vector` channel dimension.

    Args:
        values: Tensor with a `vector` dimension that can be unstacked into
            `resolution.spatial_rank` components.
        resolution: Shape of the (centered) grid.
        extrapolation: Queried via `valid_outer_faces(dim)`; the result is
            unpacked into the arguments of `GridCell.stagger`.

    Returns:
        The expanded components, stacked with `math.stack` along `vector`.
    """
    # Only the staggered resolution is read below; the cells are built on a
    # box from 0 to 1 along every dimension of `resolution`.
    # NOTE(review): this box and the final stack use `resolution.rank` / `resolution.names`,
    # while the loop uses the spatial-only variants. These presumably coincide because
    # `resolution` is expected to be purely spatial; confirm with callers.
    unit_upper = math.wrap((1,) * resolution.rank, channel(vector=resolution.names))
    unit_cells = GridCell(resolution, Box(0, unit_upper))
    per_axis_values = values.vector.unstack(resolution.spatial_rank)
    expanded = [
        math.expand(
            part,
            unit_cells.stagger(axis, *extrapolation.valid_outer_faces(axis)).resolution,
        )
        for axis, part in zip(resolution.spatial.names, per_axis_values)
    ]
    return math.stack(expanded, channel(vector=resolution.names))