def staggered_elements(resolution: Shape, bounds: Box, extrapolation: math.Extrapolation):
    """
    Builds the staggered elements of a grid, one entry per dimension of `resolution`.

    A `GridCell` geometry is created from `resolution` and `bounds`. For every name in
    `resolution.names`, the cells are staggered along that dimension using the pair
    returned by `extrapolation.valid_outer_faces()`.

    Args:
        resolution: Grid resolution; its `names` determine the staggered directions.
        bounds: Bounds passed on to `GridCell`.
        extrapolation: Queried via `valid_outer_faces(dim)` for each dimension.

    Returns:
        The per-dimension staggered geometries stacked along the channel dimension
        `staggered_direction`.
    """
    cell_geometry = GridCell(resolution, bounds)

    def faces_along(axis):
        # The two values are forwarded unchanged as the lower/upper arguments of `stagger`.
        lo, hi = extrapolation.valid_outer_faces(axis)
        return cell_geometry.stagger(axis, lo, hi)

    per_axis = [faces_along(axis) for axis in resolution.names]
    return geom.stack(per_axis, channel(staggered_direction=resolution.names))
def batch_stack(*fields, dim: str):
    """
    Stacks the given fields along the batch dimension `dim`.

    All fields must be `SampledField`s, instances of the type of the first field,
    and share the same extrapolation.

    Args:
        *fields: Fields to stack. At least one is required; the first one serves
            as the template for the result.
        dim: Name of the dimension to stack along.

    Returns:
        A copy of the first field (created via `with_`) holding the stacked values
        and, for point clouds, the stacked elements and colors as well.

    Raises:
        AssertionError: If a field is not a `SampledField` or the types do not match.
        NotImplementedError: If the extrapolations differ or the field type is
            neither `Grid` nor `PointCloud`.
        IndexError: If no fields are given.
    """
    for member in fields:
        assert isinstance(member, SampledField)
    for member in fields:
        assert isinstance(member, type(fields[0]))
    for member in fields:
        if member.extrapolation != fields[0].extrapolation:
            raise NotImplementedError("Concatenating extrapolations not supported")

    template = fields[0]
    if isinstance(template, Grid):
        stacked_values = math.batch_stack([member.values for member in fields], dim)
        return template.with_(values=stacked_values)
    if isinstance(template, PointCloud):
        stacked_elements = geom.stack(*[member.elements for member in fields], dim=dim)
        stacked_values = math.batch_stack([member.values for member in fields], dim=dim)
        stacked_colors = math.batch_stack([member.color for member in fields], dim=dim)
        return template.with_(elements=stacked_elements, values=stacked_values, color=stacked_colors)
    raise NotImplementedError(type(template))
def stack(fields, dim: Shape, dim_bounds: Box = None):
    """
    Stacks the given `SampledField`s along `dim`.

    See Also:
        `concat()`.

    Args:
        fields: List of matching `SampledField` instances.
        dim: Stack dimension as `Shape`. Size is ignored.
        dim_bounds: Only used when stacking grids along a spatial `dim`.
            The bounds of the result are `fields[0].bounds * dim_bounds`.
            If `None`, `Box(**{dim.name: len(fields)})` is used.

    Returns:
        `SampledField` matching stacked fields.

    Raises:
        AssertionError: If the fields are not all `SampledField`s of the same type.
        NotImplementedError: If the extrapolations differ or the field type is
            neither `Grid` nor `PointCloud`.
    """
    for candidate in fields:
        assert isinstance(candidate, SampledField), f"All fields must be SampledFields of the same type but got {fields}"
    for candidate in fields:
        assert isinstance(candidate, type(fields[0])), f"All fields must be SampledFields of the same type but got {fields}"
    for candidate in fields:
        if candidate.extrapolation != fields[0].extrapolation:
            raise NotImplementedError("Concatenating extrapolations not supported")

    first = fields[0]
    if isinstance(first, Grid):
        stacked_values = math.stack([candidate.values for candidate in fields], dim)
        if not spatial(dim):
            # Non-spatial stacking leaves the bounds untouched.
            return first.with_values(stacked_values)
        if dim_bounds is None:
            dim_bounds = Box(**{dim.name: len(fields)})
        grid_type = type(first)
        return grid_type(stacked_values, extrapolation=first.extrapolation, bounds=first.bounds * dim_bounds)
    if isinstance(first, PointCloud):
        stacked_elements = geom.stack([candidate.elements for candidate in fields], dim=dim)
        stacked_values = math.stack([candidate.values for candidate in fields], dim=dim)
        stacked_colors = math.stack([candidate.color for candidate in fields], dim=dim)
        return PointCloud(
            elements=stacked_elements,
            values=stacked_values,
            color=stacked_colors,
            extrapolation=first.extrapolation,
            add_overlapping=first._add_overlapping,
            bounds=first._bounds,
        )
    raise NotImplementedError(type(first))
def test_stack_type(self):
    """Stacking two boxes along a batch dimension must yield a `Box`."""
    unit_box = Box[0:1, 0:1]
    large_box = Box[0:10, 0:10]
    stacked = geom.stack([unit_box, large_box], batch('batch'))
    self.assertIsInstance(stacked, Box)
def test_stack_volume(self):
    """The volume of batch-stacked boxes is checked against one value per box."""
    boxes = [Box[0:1, 0:1], Box[0:2, 0:2]]
    stacked = geom.stack(boxes, batch('batch'))
    expected_volumes = [1, 4]
    math.assert_close(stacked.volume, expected_volumes)
def test_laplace_batched(self):
    """`field.laplace` of a grid with batch-stacked bounds must return a `CenteredGrid`."""
    boxes = [Box[0:1, 0:1], Box[0:10, 0:10]]
    batched_bounds = geom.stack(boxes, batch('batch'))
    zero_grid = CenteredGrid(0, extrapolation.ZERO, batched_bounds, x=10, y=10)
    result = field.laplace(zero_grid)
    self.assertIsInstance(result, CenteredGrid)
def test_spatial_gradient_batched(self):
    """`field.spatial_gradient` of a grid with batch-stacked bounds must return a `CenteredGrid`."""
    boxes = [Box[0:1, 0:1], Box[0:10, 0:10]]
    batched_bounds = geom.stack(boxes, batch('batch'))
    zero_grid = CenteredGrid(0, extrapolation.ZERO, batched_bounds, x=10, y=10)
    result = field.spatial_gradient(zero_grid)
    self.assertIsInstance(result, CenteredGrid)