Ejemplo n.º 1
0
    def spatial_derivatives(
        self,
        inputs: Mapping[str, tf.Tensor],
        request: Set[str] = None,
    ) -> Dict[str, tf.Tensor]:
        """Compute the requested spatial derivatives. See base class.

        Args:
          inputs: tensors keyed by name. Each requested key reads the entry
            stored under its parent key, ``self.parents[key]``.
          request: keys to compute. When ``None``, every key in
            ``self.equation.all_keys`` is computed.

        Returns:
          A dict mapping each requested key to a tensor. A key whose
          coefficients are ``None`` maps to its parent tensor unchanged. Every
          other key maps to the contraction of its coefficients with patches
          extracted from the parent tensor.
        """
        keys = self.equation.all_keys if request is None else request

        derivatives = {}
        for name in keys:
            coeffs = self.coefficients[name]
            field = inputs[self.parents[name]]

            if coeffs is None:
                # Nothing to apply: the parent tensor is used as-is.
                derivatives[name] = field
                continue

            stencil_sizes = [stencil.size for stencil in self.stencils[name]]

            definitions = self.equation.key_definitions
            child_def = definitions[name]
            parent_def = definitions[self.parents[name]]
            # Per-axis offset of this key relative to its parent; handed to
            # the patch extractor as the shift.
            offset_shifts = [
                child - parent
                for parent, child in zip(parent_def.offset, child_def.offset)
            ]

            patches = tensor_ops.extract_patches_2d(
                field, stencil_sizes, offset_shifts)
            # Contract the trailing axis of the coefficients with the trailing
            # axis of the extracted patches.
            derivatives[name] = tf.tensordot(coeffs, patches, axes=[-1, -1])
            # Invariant: the last two dimensions must match the parent's.
            assert derivatives[name].shape[-2:] == field.shape[-2:], (
                derivatives[name], field)

        return derivatives
Ejemplo n.º 2
0
 def call(self, inputs):
     """Apply constrained stencil coefficients to patches of a source field.

     Args:
       inputs: a pair ``(kernel, source)``. ``kernel`` is passed through
         ``self.constraint_layer`` to obtain the coefficients; ``source`` is
         the tensor that patches are extracted from, using the sizes of
         ``self.stencils`` and the shifts in ``self.shifts``.

     Returns:
       The einsum ``'bxys,bxys->bxy'`` of the coefficients and the patches.
       This is an elementwise product of two rank-4 operands, summed over
       their last axis.
     """
     kernel, source = inputs
     constrained = self.constraint_layer(kernel)
     patch_sizes = [s.size for s in self.stencils]
     windows = tensor_ops.extract_patches_2d(source, patch_sizes, self.shifts)
     # Weighted sum over the trailing ('s') axis at every (b, x, y) position.
     return tf.einsum('bxys,bxys->bxy', constrained, windows)