def spatial_derivatives(
    self,
    inputs: Mapping[str, tf.Tensor],
    request: Set[str] = None,
) -> Dict[str, tf.Tensor]:
    """Evaluate the requested derived keys from their parent tensors.

    See base class for the overall contract.

    For each requested key, the matching entry of ``self.coefficients`` is
    consulted:

    * ``None`` -- the parent tensor ``inputs[self.parents[key]]`` is returned
      unchanged for that key.
    * otherwise -- patches are gathered from the parent tensor with
      ``tensor_ops.extract_patches_2d``. The call uses the ``size`` of each
      of the key's stencils and the per-axis difference between the key's
      and the parent's ``offset``. The last axis of the coefficients is then
      contracted against the last axis of the patches via ``tf.tensordot``.

    Args:
      inputs: tensors keyed by name. Must hold ``self.parents[key]`` for
        every requested key.
      request: keys to evaluate. ``None`` means ``self.equation.all_keys``.

    Returns:
      A dict mapping every requested key to its tensor.

    Raises:
      AssertionError: when assertions are enabled and a computed tensor's
        trailing two dimensions differ from its parent tensor's.
    """
    keys = self.equation.all_keys if request is None else request

    derivatives = {}
    for name in keys:
        weights = self.coefficients[name]
        parent_name = self.parents[name]
        field = inputs[parent_name]

        if weights is None:
            # No stencil for this key: pass the parent tensor straight through.
            derivatives[name] = field
            continue

        patch_sizes = [stencil.size for stencil in self.stencils[name]]

        definitions = self.equation.key_definitions
        child_offset = definitions[name].offset
        parent_offset = definitions[parent_name].offset
        # Per-axis displacement of the derived key relative to its parent,
        # used to align the extracted patches.
        offset_shifts = [
            child - parent
            for parent, child in zip(parent_offset, child_offset)
        ]

        patches = tensor_ops.extract_patches_2d(
            field, patch_sizes, offset_shifts)
        value = tf.tensordot(weights, patches, axes=[-1, -1])

        # NOTE(review): presumably the trailing two axes are the spatial grid;
        # the contraction is expected to leave them unchanged.
        assert value.shape[-2:] == field.shape[-2:], (value, field)
        derivatives[name] = value
    return derivatives
def call(self, inputs):
    """Apply constrained stencil coefficients to patches of a source tensor.

    Args:
      inputs: a two-element sequence ``(kernel, source)``. ``kernel`` is fed
        through ``self.constraint_layer`` to obtain the coefficients.
        ``source`` is the tensor patches are extracted from, using the sizes
        of ``self.stencils`` and the offsets in ``self.shifts``.

    Returns:
      ``tf.einsum('bxys,bxys->bxy', coefficients, patches)``: an elementwise
      product summed over the last axis. Both operands must therefore be
      rank 4. The labels suggest (batch, x, y, stencil), but this is not
      confirmed here.
    """
    kernel, source = inputs
    weights = self.constraint_layer(kernel)
    patches = tensor_ops.extract_patches_2d(
        source,
        [stencil.size for stencil in self.stencils],
        self.shifts,
    )
    # Sum out the shared last label 's'; keep the first three axes.
    contraction = 'bxys,bxys->bxy'
    return tf.einsum(contraction, weights, patches)