def test_is_prod(self):
    """Only a product of kernels should be reported as a product covariance."""
    # A single base kernel is not a product.
    self.assertFalse(Covariance(self.se_0).is_prod())

    # A sum of kernels is not a product either.
    self.assertFalse(Covariance(self.se_0 + self.se_0).is_prod())

    # Multiplying kernels yields a product covariance.
    self.assertTrue(Covariance(self.se_0 * self.se_0).is_prod())
def expand_full_kernel(self, covariance: Covariance) -> List[Covariance]:
    """Expand a kernel and, recursively, every operand of a combination kernel.

    The result starts with ``self.expand_single_kernel(covariance)``. When
    ``covariance`` is not a base kernel, each operand in
    ``covariance.raw_kernel.parts`` is expanded recursively. For every
    expansion of an operand, a new kernel is built by substituting that
    expansion for the operand (all operands are copied) and recombining the
    operands with the same operator as the original: ``*=`` for products,
    ``+=`` for sums.

    :param covariance: Covariance to expand. ``None`` is still passed to
        ``expand_single_kernel`` but no operand expansion is attempted.
    :return: The single-kernel expansions followed by one new ``Covariance``
        per (operand, operand-expansion) pair.
    :raises TypeError: If ``covariance`` is a non-base kernel that is neither
        a sum nor a product and at least one operand expansion is produced.
    """
    import functools
    import operator

    result = self.expand_single_kernel(covariance)
    if covariance is None or covariance.is_base():
        return result

    # Choose the recombination operator once rather than per expansion.
    # In-place operators mirror the ``*=`` / ``+=`` accumulation semantics.
    if covariance.is_prod():
        combine = operator.imul
    elif covariance.is_sum():
        combine = operator.iadd
    else:
        combine = None

    for i, operand in enumerate(covariance.raw_kernel.parts):
        for expanded in self.expand_full_kernel(Covariance(operand)):
            if combine is None:
                # Name the kernel class; ``covariance`` itself is always a Covariance.
                raise TypeError(
                    f'Unknown combination kernel class {covariance.raw_kernel.__class__.__name__}'
                )
            # ``parts`` is re-read here (not cached) in case the recursive
            # expansion re-parents operands — NOTE(review): confirm whether needed.
            parts = covariance.raw_kernel.parts
            new_operands = parts[:i] + [expanded.raw_kernel] + parts[i + 1:]
            # Work on copies, presumably so the recombined kernel does not take
            # ownership of the original operands.
            new_operands = [op.copy() for op in new_operands]
            result.append(Covariance(functools.reduce(combine, new_operands)))
    return result