def stagger(field: CenteredGrid,
            face_function: Callable,
            extrapolation: math.extrapolation.Extrapolation,
            type: type = StaggeredGrid):
    """
    Creates a new grid by evaluating `face_function` given two neighbouring cells.
    One layer of missing cells is inferred from the extrapolation.

    This method returns a Field of type `type` which must be either StaggeredGrid or CenteredGrid.
    When returning a StaggeredGrid, the new values are sampled at the faces of neighbouring cells.
    When returning a CenteredGrid, the new grid has the same resolution as `field`.

    Args:
        field: centered grid
        face_function: function mapping (value1: Tensor, value2: Tensor) -> center_value: Tensor
        extrapolation: extrapolation mode of the returned grid. Has no effect on the values.
        type: one of (StaggeredGrid, CenteredGrid)

    Returns:
        grid of type matching the `type` argument

    Raises:
        ValueError: if `type` is neither StaggeredGrid nor CenteredGrid.
    """
    if type == StaggeredGrid:
        all_lower = []
        all_upper = []
        for dim in field.shape.spatial.names:
            lo_valid, up_valid = extrapolation.valid_outer_faces(dim)
            # An outer face is only stored if `extrapolation` marks it valid.
            # The lower neighbour of face k is cell k-1 and the upper neighbour is cell k.
            # Positive widths pad with `field.extrapolation`; negative widths are used to trim cells.
            # Lower and upper widths depend on lo_valid / up_valid independently, so
            # asymmetric extrapolations (only one valid outer face) are handled too.
            width_lower = {dim: (int(lo_valid), int(up_valid) - 1)}
            width_upper = {dim: (int(lo_valid) - 1, int(up_valid))}
            all_lower.append(math.pad(field.values, width_lower, field.extrapolation))
            all_upper.append(math.pad(field.values, width_upper, field.extrapolation))
        all_upper = math.stack(all_upper, channel('vector'))
        all_lower = math.stack(all_lower, channel('vector'))
        values = face_function(all_lower, all_upper)
        result = StaggeredGrid(values, bounds=field.bounds, extrapolation=extrapolation)
        assert result.shape.spatial == field.shape.spatial
        return result
    elif type == CenteredGrid:
        left, right = math.shift(field.values, (-1, 1), padding=field.extrapolation, stack_dim=channel('vector'))
        values = face_function(left, right)
        return CenteredGrid(values, bounds=field.bounds, extrapolation=extrapolation)
    else:
        raise ValueError(type)
def __mul__(self, other):
    """
    Cartesian product of this box with `other`.

    The lower and upper corners of both boxes are concatenated component-wise,
    so the result spans the dimensions of `self` followed by those of `other`.

    Args:
        other: Another `Box`. The two boxes may have different spatial ranks.

    Returns:
        `Box` of rank `self.spatial_rank + other.spatial_rank`, or `NotImplemented`
        if `other` is not a `Box`.
    """
    if not isinstance(other, Box):
        return NotImplemented
    # Each box must be unstacked with its *own* rank; using self's rank for `other`
    # fails (or mis-sizes) whenever the two ranks differ.
    lower = self._lower.vector.unstack(self.spatial_rank) + other._lower.vector.unstack(other.spatial_rank)
    upper = self._upper.vector.unstack(self.spatial_rank) + other._upper.vector.unstack(other.spatial_rank)
    names = self._upper.vector.item_names + other._upper.vector.item_names
    lower = math.stack(lower, math.channel(vector=names))
    upper = math.stack(upper, math.channel(vector=names))
    return Box(lower, upper)
def staggered_points(self, dimension):
    """
    Index-space sample locations of the staggered component along `dimension`.

    Along axis `dimension` the points lie at integer positions 0..size (faces).
    Along every other axis they lie at half-integer positions 0.5..size+0.5.

    Args:
        dimension: Index of the axis in `self.resolution` (presumably z,y,x order, judging
            by `idx_zyx`). NOTE(review): the previous code compared this value against
            axis *sizes*; confirm callers pass an axis index.

    Returns:
        Array with a leading singleton axis and the coordinates stacked along the last axis.
    """
    axes = [
        np.arange(0, size + 1, 1) if axis == dimension else np.arange(0.5, size + 1.5, 1)
        for axis, size in enumerate(self.resolution)
    ]
    idx_zyx = np.meshgrid(*axes, indexing="ij")
    return math.expand_dims(math.stack(idx_zyx, axis=-1), 0)
def union(*geometries) -> Geometry:
    """
    Union of the given geometries.

    A point lies inside the union if it lies within at least one of the geometries.

    Args:
        *geometries: arbitrary geometries with same spatial dims. Arbitrary batch dims are allowed.
            A single tuple or list of geometries may be passed instead of varargs.

    Returns:
        union Geometry
    """
    if len(geometries) == 1 and isinstance(geometries[0], (tuple, list)):
        geometries = geometries[0]
    count = len(geometries)
    if count == 0:
        return NO_GEOMETRY
    if count == 1:
        return geometries[0]
    first = geometries[0]
    if all(type(g) == type(first) for g in geometries):
        # Same concrete type: stack each variable attribute along a new 'union' instance dim.
        attribute_names = variable_attributes(first)
        stacked = {name: math.stack([getattr(g, name) for g in geometries], math.instance('union'))
                   for name in attribute_names}
        return copy_with(first, **stacked)
    # Mixed types: flatten nested unions so the result holds only base geometries.
    flat = ()
    for g in geometries:
        flat += g.geometries if isinstance(g, Union) else (g,)
    return Union(flat)
def at_faces(self, face_dimension_xyz):
    """
    Interpolates all components of `self.staggered` onto the faces normal to `face_dimension_xyz`.

    The component belonging to the face dimension is taken as-is. Every other component is
    averaged pairwise along each spatial axis: towards the upper neighbour along its own axis
    (padded at the upper end), towards the lower neighbour along the others (padded at the lower end).

    Args:
        face_dimension_xyz: Face axis in x,y,z order (0=X).

    Returns:
        Tensor with the components stacked along the last axis in x,y,z order.
    """
    dims = range(self.spatial_rank)
    rank = len(dims)
    face_dimension_zyx = rank - face_dimension_xyz - 1  # 0=Z, 1=Y, 2=X, etc.
    components = []
    for d in dims:  # z,y,x
        component = self.staggered[..., rank - d - 1]
        if d == face_dimension_zyx:
            components.append(component)
            continue
        # Interpolate other components
        # NOTE(review): for rank > 2 this also averages along axes that are neither `d` nor the
        # face axis; confirm that is intended.
        t = component
        for d2 in dims:  # z,y,x
            # Index with tuples: NumPy >= 1.23 no longer accepts a list of slices as a multi-dim index.
            upper = tuple([slice(None)] + [(slice(1, None) if i == d2 else slice(None)) for i in dims])
            lower = tuple([slice(None)] + [(slice(-1) if i == d2 else slice(None)) for i in dims])
            t = t[upper] + t[lower]
            pad_side = [0, 1] if d2 == d else [1, 0]
            t = math.pad(t, [[0, 0]] + [(pad_side if i == d2 else [0, 0]) for i in dims]) / 2
        components.append(t)
    return math.stack(components[::-1], axis=-1)
def reduce_sample(field: Field, geometry: Geometry, dim=channel('vector')) -> math.Tensor:
    """
    Similar to `sample()`, but matches channel dimensions of `geometry` with channel dimensions of this field.
    Currently, `geometry` may have at most one channel dimension.

    See Also:
        `sample()`, `Field.at()`, [Resampling overview](https://tum-pbs.github.io/PhiFlow/Fields.html#resampling-fields).

    Args:
        field: Source `Field` to sample.
        geometry: Single or batched `phi.geom.Geometry`.
        dim: Dimension of result, resulting from reduction of channel dimensions.

    Returns:
        Sampled values as a `phi.math.Tensor`
    """
    if isinstance(field, SampledField) and field.elements.shallow_equals(geometry):
        return field.values
    reduce_dims = geometry.shape.channel
    if not reduce_dims:
        # Nothing to reduce
        return field._sample(geometry)
    assert reduce_dims.rank == 1, "Only single-dimension reduction supported."
    if field.shape.channel.volume > 1:
        # One field component per geometry channel entry.
        assert field.shape.channel.volume == reduce_dims.volume, f"Cannot sample field with channels {field.shape.channel} at elements with channels {geometry.shape.channel}."
        pairs = zip(unstack(field, field.shape.channel.name), geometry.unstack(reduce_dims.name))
        sampled = [component._sample(element) for component, element in pairs]
    else:
        # Scalar field: sample the whole field at every geometry entry.
        sampled = [field._sample(element) for element in geometry.unstack(reduce_dims.name)]
    result_dim = dim._with_item_names(reduce_dims.item_names)
    return math.stack(sampled, result_dim)
def sample(field: Field, geometry: Geometry) -> math.Tensor:
    """
    Computes the field value inside the volume of the (batched) `geometry`.

    The field value may be determined by integrating over the volume, sampling the central value or any other way.

    The batch dimensions of `geometry` are matched with this field.
    The `geometry` must not share any channel dimensions with this field.
    Spatial dimensions of `geometry` can be used to sample a grid of geometries.

    See Also:
        `reduce_sample()`, `Field.at()`, [Resampling overview](https://tum-pbs.github.io/PhiFlow/Fields.html#resampling-fields).

    Args:
        field: Source `Field` to sample.
        geometry: Single or batched `phi.geom.Geometry`.

    Returns:
        Sampled values as a `phi.math.Tensor`
    """
    assert all(dim not in field.shape for dim in geometry.shape.channel)
    geometry_channels = geometry.shape.channel
    if geometry_channels:
        # Sample each channel entry of the geometry separately and re-stack the results.
        parts = [field._sample(element) for element in geometry.unstack(geometry_channels.name)]
        return math.stack(parts, geometry_channels)
    if isinstance(field, SampledField) and field.elements.shallow_equals(geometry):
        # Sampling at the field's own elements: values are already available.
        return field.values
    return field._sample(geometry)
def lies_inside(self, location: math.Tensor):
    """
    Tests `location` against every member geometry and stacks the results.

    If `location` already carries the stack dimension of `self.geometries`, each geometry is
    tested against its own slice; otherwise every geometry is tested against the full `location`.

    Returns:
        The per-geometry `lies_inside` results stacked along `self.geometries.shape`.
    """
    stack_shape = self.geometries.shape
    if stack_shape in location.shape:
        per_geometry = location.unstack(stack_shape.name)
    else:
        per_geometry = [location] * len(self.geometries)
    results = [geometry.lies_inside(loc) for geometry, loc in zip(self.geometries, per_geometry)]
    return math.stack(results, stack_shape)
def upsample2x(tensor, interpolation="LINEAR"):
    """
    Doubles the size of every spatial axis of `tensor` using linear interpolation.

    Each cell is replaced by two samples: 0.25*previous + 0.75*current and
    0.75*current + 0.25*next. The neighbours at the border come from "SYMMETRIC" padding.

    Args:
        tensor: Native tensor whose first and last axes are treated as non-spatial
            (presumably (batch, spatial..., channels); confirm with callers).
        interpolation: Only "linear" (case-insensitive) is supported.

    Returns:
        Tensor with every spatial axis twice as long.

    Raises:
        ValueError: if `interpolation` is not linear.
    """
    if interpolation.lower() != "linear":
        raise ValueError("Only linear interpolation supported")
    dims = range(spatial_rank(tensor))
    vlen = tensor.shape[-1]
    spatial_dims = tensor.shape[1:-1]
    tensor = math.pad(tensor, [[0, 0]] + [[1, 1]] * spatial_rank(tensor) + [[0, 0]], "SYMMETRIC")

    def _index(dim, dim_slice):
        # Full index tuple (not a list: NumPy >= 1.23 rejects lists of slices) slicing only axis `dim`.
        return tuple([slice(None)] + [(dim_slice if i == dim else slice(None)) for i in dims] + [slice(None)])

    for dim in dims:
        prev_cells = tensor[_index(dim, slice(-2))]
        cells = tensor[_index(dim, slice(1, -1))]
        next_cells = tensor[_index(dim, slice(2, None))]
        left = 0.75 * cells + 0.25 * next_cells
        right = 0.25 * prev_cells + 0.75 * cells
        # Interleave the two samples per cell along `dim`, then merge them into one axis of length 2n.
        combined = math.stack([right, left], axis=2 + dim)
        tensor = math.reshape(combined, [-1] + [
            spatial_dims[dim] * 2 if i == dim else tensor.shape[i + 1] for i in dims
        ] + [vlen])
    return tensor
def with_extrapolation(self, extrapolation: math.Extrapolation):
    """
    Returns a copy of this staggered grid that uses `extrapolation`.

    If the new extrapolation stores the same outer faces in every dimension, the values are reused
    unchanged. Otherwise each component is padded (or trimmed, for negative widths) along its own
    dimension with the *current* extrapolation, so that it matches the faces stored by the new one.
    """
    faces_unchanged = all(
        extrapolation.valid_outer_faces(dim) == self.extrapolation.valid_outer_faces(dim)
        for dim in self.resolution.names)
    if faces_unchanged:
        return StaggeredGrid(self.values, extrapolation=extrapolation, bounds=self.bounds)
    adjusted = []
    for dim, component in zip(self.shape.spatial.names, self.values.unstack('vector')):
        old_lo, old_hi = map(int, self.extrapolation.valid_outer_faces(dim))
        new_lo, new_hi = map(int, extrapolation.valid_outer_faces(dim))
        adjusted.append(math.pad(component, {dim: (new_lo - old_lo, new_hi - old_hi)}, self.extrapolation))
    return StaggeredGrid(math.stack(adjusted, channel('vector')), extrapolation=extrapolation, bounds=self.bounds)
def test_stacked_shapes(self):
    # Unstacking and re-stacking along any dim must keep the dim names and the total volume.
    t0 = math.ones(batch(batch=10) & spatial(x=4, y=3) & channel(vector=2))
    expected_names = set(t0.shape.names)
    for dim_name in t0.shape.names:
        restacked = math.stack(t0.unstack(dim_name), t0.shape[dim_name].with_sizes([None]))
        self.assertEqual(expected_names, set(restacked.shape.names))
        self.assertEqual(t0.shape.volume, restacked.shape.volume)
def distribute_points(density, particles_per_cell=1, distribution='uniform'):
    """
    Distribute points according to the distribution specified in density.

    For every cell where channel 0 of `density` is positive, `particles_per_cell` points are
    created: at the cell centre ('center') or at a uniformly random offset inside the cell ('uniform').

    :param density: binary tensor; the code reads `density[batch, ..., 0]`
    :param particles_per_cell: integer
    :param distribution: 'uniform' or 'center'
    :return: tensor of shape (batch_size, point_count, rank)
    :raises ValueError: if the batch entries have different numbers of active cells, so their
        point arrays cannot be stacked.
    """
    assert distribution in ('center', 'uniform')
    static_batch = math.staticshape(density)[0]
    # An unknown static batch size is treated as a single batch entry.
    batch_size = static_batch if static_batch is not None else 1
    index_array = []
    for b in range(batch_size):
        indices = math.to_float(math.where(density[b, ..., 0] > 0))
        temp = []
        for _ in range(particles_per_cell):
            if distribution == 'center':
                temp.append(indices + 0.5)
            elif distribution == 'uniform':
                temp.append(indices + math.random_uniform(math.shape(indices)))
        index_array.append(math.concat(temp, axis=0))
    try:
        stacked = math.stack(index_array)
    except ValueError as err:
        raise ValueError("all arrays in the batch must have the same number of active cells.") from err
    return stacked
def stack(geometries: List[Geometry], dim: Shape):
    """
    Stacks `geometries` along `dim`.
    The size of `dim` is ignored.

    Geometries of one identical type (other than `GridCell`) are merged by stacking each of their
    variable attributes. Anything else is wrapped in a `GeometryStack`.
    """
    first = geometries[0]
    homogeneous = all(type(g) == type(first) and not isinstance(g, GridCell) for g in geometries)
    if not homogeneous:
        return GeometryStack(math.layout(geometries, dim))
    attribute_names = variable_attributes(first)
    stacked = {name: math.stack([getattr(g, name) for g in geometries], dim) for name in attribute_names}
    return copy_with(first, **stacked)
def downsample2x(grid: Grid) -> GridType:
    """
    Reduces the number of sample points by a factor of 2 in each spatial dimension.
    The new values are determined via linear interpolation.

    See Also:
        `upsample2x()`.

    Args:
        grid: `CenteredGrid` or `StaggeredGrid`.

    Returns:
        `Grid` of same type as `grid`.

    Raises:
        ValueError: if `grid` is neither a `CenteredGrid` nor a `StaggeredGrid`.
    """
    if isinstance(grid, CenteredGrid):
        downsampled = math.downsample2x(grid.values, grid.extrapolation)
        return CenteredGrid(downsampled, bounds=grid.bounds, extrapolation=grid.extrapolation)
    if isinstance(grid, StaggeredGrid):
        components = []
        for dim, component in zip(grid.shape.spatial.names, unstack(grid, 'vector')):
            # Along the component's own axis keep every second face; interpolate along all other axes.
            every_other_face = component.values[{dim: slice(None, None, 2)}]
            components.append(math.downsample2x(every_other_face, grid.extrapolation,
                                                dims=grid.shape.spatial.without(dim)))
        return StaggeredGrid(math.stack(components, channel('vector')),
                             bounds=grid.bounds, extrapolation=grid.extrapolation)
    raise ValueError(type(grid))
def bake_extrapolation(grid: GridType) -> GridType:
    """
    Pads `grid` with its current extrapolation.
    For `StaggeredGrid`s, the resulting grid will have a consistent shape, independent of the original extrapolation.

    Args:
        grid: `CenteredGrid` or `StaggeredGrid`.

    Returns:
        Padded grid with extrapolation `phi.math.extrapolation.NONE`.

    Raises:
        ValueError: if `grid` is neither a `StaggeredGrid` nor a `CenteredGrid`.
    """
    if grid.extrapolation == math.extrapolation.NONE:
        return grid  # already baked
    if isinstance(grid, StaggeredGrid):
        padded = []
        for dim, component in zip(grid.shape.spatial.names, grid.values.unstack('vector')):
            lower_valid, upper_valid = grid.extrapolation.valid_outer_faces(dim)
            # Add the outer face layer on each side where it is not already stored.
            widths = (0 if lower_valid else 1, 0 if upper_valid else 1)
            padded.append(math.pad(component, {dim: widths}, grid.extrapolation))
        return StaggeredGrid(math.stack(padded, channel('vector')), bounds=grid.bounds,
                             extrapolation=math.extrapolation.NONE)
    if isinstance(grid, CenteredGrid):
        return pad(grid, 1).with_extrapolation(math.extrapolation.NONE)
    raise ValueError(f"Not a valid grid: {grid}")
def test_stacked_get(self):
    # Slicing a stacked tensor must hand back the original component tensors.
    t0 = math.ones(batch(batch=10) & spatial(x=4, y=3) & channel(vector=2))
    components = t0.unstack('vector')
    stacked = math.stack(components, channel('channel'))
    self.assertEqual(components, stacked.channel.unstack())
    first, second = components
    assert first is stacked.channel[0]
    assert second is stacked.channel[1:2].channel.unstack()[0]
    self.assertEqual(4, len(stacked.x.unstack()))
def test_plot_multi_1d(self):
    # A 1D grid with two named curves stacked along a channel dim.
    def curves(x):
        return math.stack({'sin': math.sin(x), 'cos': math.cos(x)}, channel('curves'))

    grid = CenteredGrid(curves, x=100, bounds=Box(x=2 * math.pi))
    self._test_plot(grid)
def unstack_staggered_tensor(data: Tensor, extrapolation: math.Extrapolation) -> TensorStack:
    """
    Trims each component of a padded staggered tensor to the faces stored for `extrapolation`.

    In every spatial dimension the last entry is dropped. Along the component's own dimension,
    the first / last entry is only kept if `extrapolation` marks the lower / upper outer face as valid.

    Returns:
        The trimmed components stacked along channel dimension 'vector'.
    """
    spatial_names = data.shape.spatial.names
    trimmed = []
    for dim, component in zip(spatial_names, data.unstack('vector')):
        lower_valid, upper_valid = extrapolation.valid_outer_faces(dim)
        index = {other: slice(0, -1) for other in spatial_names}
        start = 0 if lower_valid else 1
        stop = None if upper_valid else -1
        index[dim] = slice(start, stop)
        trimmed.append(component[index])
    return math.stack(trimmed, channel('vector'))
def stack(fields, dim: Shape, dim_bounds: Box = None):
    """
    Stacks the given `SampledField`s along `dim`.

    See Also:
        `concat()`.

    Args:
        fields: List of matching `SampledField` instances.
        dim: Stack dimension as `Shape`. Size is ignored.
        dim_bounds: Only used when stacking grids along a spatial `dim`: the bounds spanned by the
            new dimension. Defaults to `Box` of size `len(fields)` along `dim`.

    Returns:
        `SampledField` matching stacked fields.

    Raises:
        NotImplementedError: if the fields have different extrapolations or are neither grids nor point clouds.
    """
    assert all(isinstance(f, SampledField) for f in fields), f"All fields must be SampledFields of the same type but got {fields}"
    assert all(isinstance(f, type(fields[0])) for f in fields), f"All fields must be SampledFields of the same type but got {fields}"
    first = fields[0]
    if any(f.extrapolation != first.extrapolation for f in fields):
        raise NotImplementedError("Concatenating extrapolations not supported")
    if isinstance(first, Grid):
        values = math.stack([f.values for f in fields], dim)
        if not spatial(dim):
            return first.with_values(values)
        # A new spatial dim extends the grid's bounds as well.
        if dim_bounds is None:
            dim_bounds = Box(**{dim.name: len(fields)})
        return type(first)(values, extrapolation=first.extrapolation, bounds=first.bounds * dim_bounds)
    if isinstance(first, PointCloud):
        return PointCloud(elements=geom.stack([f.elements for f in fields], dim=dim),
                          values=math.stack([f.values for f in fields], dim=dim),
                          color=math.stack([f.color for f in fields], dim=dim),
                          extrapolation=first.extrapolation,
                          add_overlapping=first._add_overlapping,
                          bounds=first._bounds)
    raise NotImplementedError(type(first))
def _sample(self, geometry: Geometry, outside_handling='discard') -> Tensor:
    """
    Samples this field at `geometry`.

    Args:
        geometry: The field's own elements, a `GridCell`, or a `GeometryStack` of those.
        outside_handling: Passed through to `self.grid_scatter` for grid geometries
            (including every grid inside a `GeometryStack`).

    Returns:
        `self.values` if `geometry` equals the elements, otherwise the scattered / stacked samples.

    Raises:
        NotImplementedError: for any other geometry type.
    """
    if geometry == self.elements:
        return self.values
    elif isinstance(geometry, GridCell):
        return self.grid_scatter(geometry.bounds, geometry.resolution, outside_handling)
    elif isinstance(geometry, GeometryStack):
        # Forward `outside_handling`; previously it was silently reset to the default for stacked geometries.
        sampled = [self._sample(g, outside_handling=outside_handling) for g in geometry.geometries]
        return math.stack(sampled, geometry.geometries.shape)
    else:
        raise NotImplementedError()
def getpoints(box, resolution):
    """
    Builds a `CenteredGrid` holding the coordinates of the cell centres of a regular grid.

    Local coordinates (0.5/n, ..., 1 - 0.5/n per axis) are generated in [0, 1] and mapped into
    `box` via `box.local_to_global`.

    Args:
        box: Box whose local coordinates are mapped to global ones.
        resolution: Number of cells per axis.

    Returns:
        `CenteredGrid` flagged with `SAMPLE_POINTS`.
    """
    per_axis = [np.linspace(0.5 / n, 1 - 0.5 / n, n) for n in resolution]
    idx_zyx = np.meshgrid(*per_axis, indexing="ij")
    local_coords = math.expand_dims(math.stack(idx_zyx, axis=-1), 0).astype(np.float32)
    global_coords = box.local_to_global(local_coords)
    return CenteredGrid(global_coords, box, name='grid_centers(%s, %s)' % (box, resolution), flags=[SAMPLE_POINTS])
def at_centers(self):
    """
    Interpolates the staggered components to the cell centres.

    For zyx axis `d`, the component stored at channel `rank - d - 1` is averaged from its two
    bounding faces along `d`. The last entry is dropped along every other axis.

    Returns:
        Tensor with the centred components stacked along the last axis in x,y,z order.
    """
    rank = self.spatial_rank
    dims = range(rank)
    centred = []
    for d in dims:  # z,y,x
        component = rank - d - 1
        # Tuple indices: NumPy >= 1.23 rejects a list of slices as a multidimensional index.
        upper = tuple([slice(None)] + [(slice(1, None) if i == d else slice(-1)) for i in dims] + [component])
        lower = tuple([slice(None)] + [slice(-1) for _ in dims] + [component])
        face_sum = self.staggered[upper] + self.staggered[lower]
        # Mean of the two bounding faces: always divide by 2. The old code divided by `rank`,
        # which only coincided in 2D.
        centred.append(face_sum / 2)
    return math.stack(centred[::-1], axis=-1)
def test_stacked_native(self):
    t0 = math.ones(batch(batch=10) & spatial(x=4, y=3) & channel(vector=2))
    stacked = math.stack(t0.unstack('vector'), channel('vector2'))
    math.assert_close(stacked, t0)
    self.assertEqual((10, 4, 3, 2), stacked.native(stacked.shape).shape)
    cases = [
        ((4, 3, 2, 10), ('x', 'y', 'vector2', 'batch')),
        # this should re-stack since only the stacked dimension position is different
        ((2, 10, 3, 4), ('vector2', 'batch', 'y', 'x')),
    ]
    for expected_shape, order in cases:
        self.assertEqual(expected_shape, stacked.native(order=order).shape)
def closest_values(self, points: Geometry):
    """
    Collects `closest_values` from every vector component and stacks them along 'vector'.

    If `points` has a 'staggered_direction' dimension, component i uses the i-th slice of `points`;
    otherwise all components use the same `points`.
    """
    assert 'vector' not in points.shape
    if 'staggered_direction' not in points.shape:
        channels = [component.closest_values(points) for component in self.vector.unstack()]
    else:
        per_component = points.unstack('staggered_direction')
        channels = [component.closest_values(p) for p, component in zip(per_component, self.vector.unstack())]
    return math.stack(channels, channel('vector'))
def _central_diff_nd(field, dims):
    """
    Central differences f[i+1] - f[i-1] of channel 0 of `field` along each axis in `dims`.

    The field is first padded by one cell on every spatial axis ("symmetric" mode). The result is
    not divided by 2 or by any grid spacing.

    Args:
        field: Native tensor whose first and last axes are non-spatial
            (presumably (batch, spatial..., channels); confirm with callers).
        dims: Spatial axis indices. The index tuples are built from `dims`, so it presumably must
            enumerate all spatial axes, e.g. `range(spatial_rank(field))`.

    Returns:
        Differences stacked along a new last axis, in reverse order of `dims`.
    """
    field = math.pad(field, [[0, 0]] + [[1, 1]] * spatial_rank(field) + [[0, 0]], "symmetric")
    df_dq = []
    for dimension in dims:
        # Tuple indices: NumPy >= 1.23 rejects a list of slices as a multidimensional index.
        upper = tuple([slice(None)] + [(slice(2, None) if i == dimension else slice(1, -1)) for i in dims] + [0])
        lower = tuple([slice(None)] + [(slice(-2) if i == dimension else slice(1, -1)) for i in dims] + [0])
        df_dq.append(field[upper] - field[lower])
    return math.stack(df_dq[::-1], axis=-1)
def _forward_diff_nd(field, dims):
    """
    Forward differences f[i+1] - f[i] of `field` along each axis in `dims`.

    Each difference is padded by one entry at the upper end with `math.pad`'s default mode
    (presumably zeros), so it keeps the input's size along that axis.

    Args:
        field: Native tensor indexed as (batch, spatial...). The pad widths cover only these axes,
            so presumably there is no trailing channel axis; confirm with callers.
        dims: Spatial axis indices, presumably all of them (the index tuples are built from `dims`).

    Returns:
        Differences stacked along a new last axis, in reverse order of `dims`.
    """
    df_dq = []
    for dimension in dims:
        # Tuple indices: NumPy >= 1.23 rejects a list of slices as a multidimensional index.
        upper = tuple([slice(None)] + [(slice(1, None) if i == dimension else slice(None)) for i in dims])
        lower = tuple([slice(None)] + [(slice(-1) if i == dimension else slice(None)) for i in dims])
        diff = field[upper] - field[lower]
        padded = math.pad(diff, [[0, 0]] + [([0, 1] if i == dimension else [0, 0]) for i in dims])
        df_dq.append(padded)
    return math.stack(df_dq[::-1], axis=-1)
def _stagger_sample(self, box, resolution):
    """
    Samples this field on a staggered grid.

    In addition to sampling, extrapolates the field using an occupancy mask generated from the points.
    Each vector component `d` is scattered onto faces whose shape is `resolution + 1` along every axis,
    using `self.mode` to combine duplicates. The occupancy mask is eroded along every axis, and the
    resulting StaggeredGrid is extrapolated from that mask with `voxel_distance=2`.

    :param box: physical dimensions of the grid
    :param resolution: grid resolution
    :return: StaggeredGrid
    """
    resolution = np.array(resolution)
    # Cell index of every sample point, clamped into [0, resolution - 1].
    valid_indices = math.to_int(math.floor(self.sample_points))
    valid_indices = math.minimum(math.maximum(0, valid_indices), resolution - 1)
    # Correct format for math.scatter
    valid_indices = batch_indices(valid_indices)
    # Occupancy mask: value 1 scattered into every cell that contains a sample point
    # (presumably 0 elsewhere), with duplicates handled as 'any'.
    active_mask = math.scatter(self.sample_points, valid_indices, 1, math.concat([[valid_indices.shape[0]], resolution, [1]], axis=-1), duplicates_handling='any')
    # One layer of constant padding on every spatial axis so the neighbour slices below stay in bounds.
    mask = math.pad(active_mask, [[0, 0]] + [[1, 1]] * self.rank + [[0, 0]], "constant")
    if isinstance(self.data, (int, float, np.ndarray)):
        # Broadcast constant data to the shape of the sample points.
        values = math.zeros_like(self.sample_points) + self.data
    else:
        values = self.data
    result = []
    ones_1d = math.unstack(math.ones_like(values), axis=-1)[0]
    # Every staggered component is scattered onto resolution + 1 entries along every axis.
    staggered_shape = [i + 1 for i in resolution]
    dx = box.size / resolution
    dims = range(len(resolution))
    for d in dims:
        # Shift the points by half a cell along d before flooring, to pick the face index along d.
        # NOTE(review): sample_points are floored directly into cell indices above (index space),
        # but this offset is 0.5 * dx (physical units); the two only agree if dx == 1 — confirm.
        staggered_offset = math.stack([(0.5 * dx[i] * ones_1d if i == d else 0.0 * ones_1d) for i in dims], axis=-1)
        indices = math.to_int(math.floor(self.sample_points + staggered_offset))
        # Face indices are clamped into [0, resolution].
        valid_indices = math.maximum(0, math.minimum(indices, resolution))
        valid_indices = batch_indices(valid_indices)
        values_d = math.expand_dims(math.unstack(values, axis=-1)[d], axis=-1)
        result.append(math.scatter(self.sample_points, valid_indices, values_d, [indices.shape[0]] + staggered_shape + [1], duplicates_handling=self.mode))
        # Erode the active mask along d: each cell becomes the minimum of itself and its lower and
        # upper neighbours along d (read from the padded mask).
        d_slice = tuple([(slice(0, -2) if i == d else slice(1,-1)) for i in dims])
        u_slice = tuple([(slice(2, None) if i == d else slice(1,-1)) for i in dims])
        active_mask = math.minimum(mask[(slice(None),) + d_slice + (slice(None),)], active_mask)
        active_mask = math.minimum(mask[(slice(None),) + u_slice + (slice(None),)], active_mask)
    # NOTE(review): this calls `unstack_staggered_tensor` with a single argument; the variant defined
    # in this file takes (data, extrapolation) — confirm which helper is in scope here.
    staggered_tensor_prep = unstack_staggered_tensor(math.concat(result, axis=-1))
    grid_values = StaggeredGrid(staggered_tensor_prep)
    # Fix values at boundary of liquids (using StaggeredGrid these might not receive a value, so we replace it with a value inside the liquid)
    grid_values, _ = extrapolate(grid_values, active_mask, voxel_distance=2)
    return grid_values
def sample_at(self, points):
    """
    Evaluates the velocity induced at `points` by the sources of this field.

    For each source, the distance vector (points - source location) is rotated by 90 degrees
    (direction depending on `GLOBAL_AXIS_ORDER`), scaled by the source strength, and summed over
    all sources. Only 2D vectors are supported.

    Raises:
        NotImplementedError: if `self.falloff` is set or the vectors are 3D.
        AssertionError: for any other vector length.
    """
    points_rank = math.spatial_rank(points)
    src_rank = math.spatial_rank(self.location)
    # --- Expand shapes to format (batch_size, points_dims..., src_dims..., channels) ---
    points = math.expand_dims(points, axis=-2, number=src_rank)
    src_points = math.expand_dims(self.location, axis=-2, number=points_rank)
    src_strength = math.expand_dims(self.strength, axis=-1)
    src_strength = math.batch_align(src_strength, 0, self.location)
    src_strength = math.expand_dims(src_strength, axis=-1, number=points_rank)
    src_axes = tuple(range(-2, -2 - src_rank, -1))
    distances = points - src_points
    if self.falloff is not None:
        raise NotImplementedError()
    strength = src_strength
    vector_length = math.staticshape(points)[-1]
    if vector_length == 2:
        # Curl in 2D
        dist_1, dist_2 = math.unstack(distances, axis=-1)
        rotated = [-dist_2, dist_1] if GLOBAL_AXIS_ORDER.is_x_first else [dist_2, -dist_1]
        velocity = strength * math.stack(rotated, axis=-1)
    elif vector_length == 3:
        raise NotImplementedError('not yet implemented')
    else:
        raise AssertionError('Vector product not available in > 3 dimensions')
    return math.sum(velocity, axis=src_axes)
def extrapolate_valid(grid: GridType, valid: GridType, distance_cells=1) -> tuple:
    """
    Extrapolates values of `grid` which are marked by nonzero values in `valid` using `phi.math.extrapolate_valid_values()`.
    If `grid` is a StaggeredGrid, its components get extrapolated independently.

    Args:
        grid: Grid holding the values for extrapolation
        valid: Grid (same type as `grid`) marking the positions for extrapolation with nonzero values
        distance_cells: Number of extrapolation steps

    Returns:
        grid: Grid with extrapolated values.
        valid: binary Grid marking all valid values after extrapolation.

    Raises:
        NotImplementedError: for grid types other than CenteredGrid and StaggeredGrid.
    """
    assert isinstance(valid, type(grid)), 'Type of valid Grid must match type of grid.'
    if isinstance(grid, CenteredGrid):
        new_values, new_valid = extrapolate_valid_values(grid.values, valid.values, distance_cells)
        return grid.with_values(new_values), valid.with_values(new_valid)
    if isinstance(grid, StaggeredGrid):
        # Recurse per component (each one is handled by the CenteredGrid branch).
        results = [extrapolate_valid(component, valid=component_valid, distance_cells=distance_cells)
                   for component, component_valid in zip(unstack(grid, 'vector'), unstack(valid, 'vector'))]
        stacked_values = math.stack([values.values for values, _ in results], channel(grid))
        stacked_valid = math.stack([mask.values for _, mask in results], channel(grid))
        return grid.with_values(stacked_values), valid.with_values(stacked_valid)
    raise NotImplementedError()
def test_tensor_from_tensor(self):
    ref = math.stack([math.zeros(spatial(x=5)), math.zeros(spatial(x=4))], batch('stack'))
    for backend in BACKENDS:
        with backend:
            # convert=False must keep the NumPy data and the original shape.
            kept = math.tensor(ref, convert=False)
            self.assertEqual(math.NUMPY, math.choose_backend(kept))
            self.assertEqual(2, kept.shape.get_size('stack'))
            self.assertEqual(('stack', 'x'), kept.shape.names)
            # Default conversion moves every stacked slice to the active backend.
            converted = math.tensor(ref)
            self.assertEqual(backend, math.choose_backend(converted))
            self.assertEqual(backend, math.choose_backend(converted.stack[0]))
            self.assertEqual(backend, math.choose_backend(converted.stack[1]))
            renamed = math.tensor(ref, batch('n1', 'n2'))
            self.assertEqual(backend, math.choose_backend(renamed))