def canonical(kernel):
    '''Sort a kernel tree into a canonical form.

    Base kernels are copied. Mask kernels are rebuilt around the canonical
    form of their base kernel. Sum and product kernels have every operand
    canonicalised; operands that are themselves of the same operator type
    are flattened into the parent; and the resulting operand list is sorted.

    Args:
        kernel: an ``fk`` kernel (BaseKernel, MaskKernel, SumKernel or
            ProductKernel).

    Returns:
        A new kernel tree in canonical form. ``kernel`` itself is not
        modified.

    Raises:
        RuntimeError: if ``kernel`` (or any subtree) is of an unrecognised
            class.
    '''
    if isinstance(kernel, fk.BaseKernel):
        return kernel.copy()
    elif isinstance(kernel, fk.MaskKernel):
        return fk.MaskKernel(kernel.ndim, kernel.active_dimension,
                             canonical(kernel.base_kernel))
    elif isinstance(kernel, fk.SumKernel):
        return fk.SumKernel(
            sorted(_canonical_flat_operands(kernel, fk.SumKernel)))
    elif isinstance(kernel, fk.ProductKernel):
        return fk.ProductKernel(
            sorted(_canonical_flat_operands(kernel, fk.ProductKernel)))
    else:
        raise RuntimeError('Unknown kernel class: %s' % kernel.__class__)


def _canonical_flat_operands(kernel, op_class):
    '''Canonicalise ``kernel``'s operands, splicing in nested ``op_class`` nodes.'''
    new_ops = []
    for op in kernel.operands:
        op_canon = canonical(op)
        # Flatten same-operator nesting, e.g. (a + (b + c)) -> (a + b + c).
        # The canonical child is already flat, so one level suffices.
        if isinstance(op_canon, op_class):
            new_ops.extend(op_canon.operands)
        else:
            new_ops.append(op_canon)
    return new_ops
def polish_to_kernel(polish_expr):
    '''Build a kernel tree from a Polish-notation (prefix) expression.

    An expression is either a leaf kernel or a tuple whose first element is
    an operator -- ``'+'`` for a sum, ``'*'`` for a product -- followed by
    sub-expressions, e.g. ``('+', k1, ('*', k2, k3))``.

    Args:
        polish_expr: the expression to convert.

    Returns:
        The corresponding kernel tree. Leaf kernels are returned as-is
        (not copied).

    Raises:
        RuntimeError: if a tuple expression is empty or its operator is
            neither ``'+'`` nor ``'*'``.
        TypeError: if a leaf is not an ``fk.Kernel``.
    '''
    if isinstance(polish_expr, tuple):
        if not polish_expr:
            raise RuntimeError('Empty Polish expression')
        op, args = polish_expr[0], polish_expr[1:]
        if op == '+':
            return fk.SumKernel([polish_to_kernel(e) for e in args])
        elif op == '*':
            return fk.ProductKernel([polish_to_kernel(e) for e in args])
        else:
            # Wrap in a 1-tuple: if ``op`` is itself a tuple, bare ``% op``
            # would be treated as the format-argument tuple and raise TypeError.
            raise RuntimeError('Unknown operator: %s' % (op,))
    if not isinstance(polish_expr, fk.Kernel):
        # Explicit check rather than ``assert``, which is stripped under -O.
        raise TypeError('Expected an fk.Kernel leaf, got %r' % (polish_expr,))
    return polish_expr
def expand(kernel, grammar):
    '''Return the kernels produced by one grammar expansion anywhere in ``kernel``.

    Expansions are taken at the root via ``expand_single_tree``. They are
    also taken recursively at every subtree:

    * for a mask kernel, the base kernel is expanded and each result is
      re-wrapped in the same mask;
    * for a sum or product, each operand in turn is replaced by each of its
      expansions, with all operands copied into the new kernel.

    Args:
        kernel: an ``fk`` kernel tree.
        grammar: passed through unchanged to ``expand_single_tree``.

    Returns:
        A new list of kernels (root expansions first, then subtree
        expansions).

    Raises:
        RuntimeError: if ``kernel`` (or any subtree) is of an unrecognised
            class.
    '''
    # Copy so that appending below never mutates a list owned by
    # expand_single_tree. NOTE(review): whether it could return a shared or
    # cached list is not visible here; the copy is cheap insurance.
    result = list(expand_single_tree(kernel, grammar))
    if isinstance(kernel, fk.BaseKernel):
        pass
    elif isinstance(kernel, fk.MaskKernel):
        result.extend(
            fk.MaskKernel(kernel.ndim, kernel.active_dimension, e)
            for e in expand(kernel.base_kernel, grammar))
    elif isinstance(kernel, fk.SumKernel):
        result.extend(_expand_operands(kernel, grammar, fk.SumKernel))
    elif isinstance(kernel, fk.ProductKernel):
        result.extend(_expand_operands(kernel, grammar, fk.ProductKernel))
    else:
        raise RuntimeError('Unknown kernel class: %s' % kernel.__class__)
    return result


def _expand_operands(kernel, grammar, op_class):
    '''Yield ``op_class`` kernels with one operand replaced by each of its expansions.'''
    operands = kernel.operands
    for i, operand in enumerate(operands):
        for expanded in expand(operand, grammar):
            new_ops = operands[:i] + [expanded] + operands[i + 1:]
            # Copy every operand so that the new kernel does not alias
            # subtrees of ``kernel``. Distinct loop name avoids shadowing
            # ``operand``.
            yield op_class([child.copy() for child in new_ops])