예제 #1
0
def canonical(kernel):
    '''Sort a kernel tree into a canonical form.

    - Base kernels are returned as copies.
    - Mask kernels are rebuilt around the canonical form of their base kernel.
    - Sum and product kernels are flattened: any operand that is itself a
      sum (resp. product) has its operands spliced into the parent. The
      resulting operand list is then sorted. This relies on the kernel
      classes defining an ordering.

    Args:
        kernel: root of an ``fk`` kernel tree.

    Returns:
        A newly built kernel tree in canonical form.

    Raises:
        RuntimeError: if the tree contains an unsupported kernel class.
    '''
    if isinstance(kernel, fk.BaseKernel):
        return kernel.copy()
    elif isinstance(kernel, fk.MaskKernel):
        return fk.MaskKernel(kernel.ndim, kernel.active_dimension,
                             canonical(kernel.base_kernel))
    elif isinstance(kernel, fk.SumKernel):
        return fk.SumKernel(
            sorted(_canonical_flat_operands(kernel, fk.SumKernel)))
    elif isinstance(kernel, fk.ProductKernel):
        return fk.ProductKernel(
            sorted(_canonical_flat_operands(kernel, fk.ProductKernel)))
    else:
        # Single formatted message; passing the class as a second argument
        # would make str(exc) render as a tuple.
        raise RuntimeError('Unknown kernel class: %s' % (kernel.__class__,))


def _canonical_flat_operands(kernel, cls):
    '''Return the canonical operands of ``kernel``, flattening children of type ``cls``.

    A child whose canonical form is a ``cls`` (e.g. a sum nested in a sum)
    contributes its own operands instead of itself. Because ``canonical``
    has already flattened that child, one level of splicing suffices.
    '''
    new_ops = []
    for op in kernel.operands:
        op_canon = canonical(op)
        if isinstance(op_canon, cls):
            new_ops.extend(op_canon.operands)
        else:
            new_ops.append(op_canon)
    return new_ops
예제 #2
0
def polish_to_kernel(polish_expr):
    '''Build a kernel tree from a Polish (prefix) notation expression.

    An expression is either:

    - an ``fk.Kernel`` instance (a leaf, returned unchanged), or
    - a tuple ``(op, e1, e2, ...)`` where ``op`` is ``'+'`` (sum) or
      ``'*'`` (product) and each ``e_i`` is itself an expression.

    Args:
        polish_expr: the expression to convert.

    Returns:
        The corresponding ``fk.SumKernel`` / ``fk.ProductKernel`` tree, or
        the leaf kernel itself.

    Raises:
        RuntimeError: for an empty tuple or an unknown operator.
        TypeError: if a leaf is not an ``fk.Kernel``. An explicit check is
            used because ``assert`` is stripped under ``python -O``.
    '''
    if isinstance(polish_expr, tuple):
        if not polish_expr:
            raise RuntimeError('Empty Polish expression: no operator given')
        op, args = polish_expr[0], polish_expr[1:]
        if op == '+':
            return fk.SumKernel([polish_to_kernel(e) for e in args])
        elif op == '*':
            return fk.ProductKernel([polish_to_kernel(e) for e in args])
        else:
            # Wrap in a 1-tuple so a tuple-valued operator is formatted
            # rather than unpacked as extra %-arguments.
            raise RuntimeError('Unknown operator: %s' % (op,))
    if not isinstance(polish_expr, fk.Kernel):
        raise TypeError('Expected an fk.Kernel or an operator tuple, got %r'
                        % (polish_expr,))
    return polish_expr
예제 #3
0
def expand(kernel, grammar):
    '''Return the expansions of a kernel tree under ``grammar``.

    The list starts with ``expand_single_tree(kernel, grammar)``. It is
    followed by every tree obtained by recursively expanding exactly one
    sub-kernel and rebuilding the surrounding tree around that expansion:

    - the base kernel of a mask kernel, or
    - one operand of a sum / product kernel.

    Base kernels contribute only the ``expand_single_tree`` results.

    Args:
        kernel: root of an ``fk`` kernel tree.
        grammar: grammar object passed through to ``expand_single_tree``.

    Returns:
        A list of kernel trees. This is the list returned by
        ``expand_single_tree``, extended in place.

    Raises:
        RuntimeError: if the tree contains an unsupported kernel class.
    '''
    result = expand_single_tree(kernel, grammar)
    if isinstance(kernel, fk.BaseKernel):
        pass
    elif isinstance(kernel, fk.MaskKernel):
        result += [
            fk.MaskKernel(kernel.ndim, kernel.active_dimension, e)
            for e in expand(kernel.base_kernel, grammar)
        ]
    elif isinstance(kernel, fk.SumKernel):
        result += _expand_one_operand(kernel, grammar, fk.SumKernel)
    elif isinstance(kernel, fk.ProductKernel):
        result += _expand_one_operand(kernel, grammar, fk.ProductKernel)
    else:
        # Single formatted message; passing the class as a second argument
        # would make str(exc) render as a tuple.
        raise RuntimeError('Unknown kernel class: %s' % (kernel.__class__,))
    return result


def _expand_one_operand(kernel, grammar, combine):
    '''Expand each operand of ``kernel`` in turn and rebuild the tree with ``combine``.

    Returns one ``combine(operands)`` per (operand index, expansion) pair,
    in operand order.
    '''
    expansions = []
    for i, target in enumerate(kernel.operands):
        for e in expand(target, grammar):
            new_ops = kernel.operands[:i] + [e] + kernel.operands[i + 1:]
            # Copy every operand, presumably so the rebuilt tree does not
            # alias kernel's operand objects.
            expansions.append(combine([node.copy() for node in new_ops]))
    return expansions