예제 #1
0
        def conv(scope, operator, container):
            """Convert a fitted clustering model into nearest-centre ONNX ops.

            Computes the squared euclidean distance from every input row to
            every cluster centre as ``||x||^2 - 2 * x . c + ||c||^2``. It then
            emits two outputs:

            * ``out[:1]``: the index of the smallest distance per row (ArgMin).
            * ``out[1:]``: the square root of the squared distances.

            :param scope: converter scope, forwarded to ``add_to``
            :param operator: operator whose ``raw_operator`` exposes
                ``cluster_centers_`` (presumably a fitted KMeans) and whose
                ``inputs[0]`` is the feature tensor
            :param container: graph container receiving the nodes; its
                ``target_opset`` is passed to every operator
            """
            X = operator.inputs[0]
            out = operator.outputs
            op = operator.raw_operator
            opv = container.target_opset

            C = op.cluster_centers_
            C2 = row_norms(C, squared=True)
            # NOTE(review): C and C2 keep the centres' dtype. If the input X has
            # a different dtype (e.g. float32 vs float64), ONNX Add/Gemm will
            # reject the mixed operands. Confirm, or cast to the input dtype.

            # Gemm's bias must broadcast to its (N, n_clusters) output. A single
            # zero does that for any batch size, including an unknown one.
            # The former np.zeros((N,)) failed for a dynamic batch (N is None).
            # It was also aligned with the cluster axis, so it was only valid
            # when N == n_clusters or N == 1.
            zeros = np.zeros((1,), dtype=C.dtype)

            # ||x||^2 per row, kept as a column so it broadcasts over clusters.
            rs = OnnxReduceSumSquare(X, axes=[1], keepdims=1, op_version=opv)
            # -2 * X @ C.T + ||x||^2
            z = OnnxAdd(
                rs,
                OnnxGemm(X, C, zeros, alpha=-2., transB=1, op_version=opv),
                op_version=opv)
            # + ||c||^2 -> squared distances, shape (N, n_clusters)
            y2 = OnnxAdd(C2, z, op_version=opv)
            lo = OnnxArgMin(y2, axis=1, keepdims=0,
                            output_names=out[:1], op_version=opv)
            # NOTE(review): rounding can make y2 slightly negative, which Sqrt
            # turns into NaN; consider clipping at 0 first.
            y2s = OnnxSqrt(y2, output_names=out[1:], op_version=opv)

            lo.add_to(scope, container)
            y2s.add_to(scope, container)
        def conv(scope, operator, container):
            """Convert a fitted clustering model into nearest-centre ONNX ops.

            Computes the squared euclidean distance from every input row to
            every cluster centre as ``||x||^2 - 2 * x . c + ||c||^2``. It then
            emits two outputs:

            * ``out[:1]``: the index of the smallest distance per row (ArgMin).
            * ``out[1:]``: the square root of the squared distances.

            All constants are cast to the dtype guessed from the input type, so
            the graph works for both float and double inputs.

            :param scope: converter scope, forwarded to ``add_to``
            :param operator: operator whose ``raw_operator`` exposes
                ``cluster_centers_`` (presumably a fitted KMeans) and whose
                ``inputs[0]`` is the feature tensor
            :param container: graph container receiving the nodes; its
                ``target_opset`` is passed to every operator
            """
            X = operator.inputs[0]
            out = operator.outputs
            op = operator.raw_operator
            opv = container.target_opset
            dtype = guess_numpy_type(X.type)

            # Every constant shares the input dtype. ONNX Add/Mul/Gemm require
            # all operands to have the same element type.
            C = op.cluster_centers_
            C2 = row_norms(C, squared=True).astype(dtype)
            C = C.astype(dtype)

            # ||x||^2 per row, as an (N, 1) column (keepdims=1).
            rs = OnnxReduceSumSquare(X,
                                     axes=[1],
                                     keepdims=1,
                                     op_version=opv)

            # Gemm's bias must broadcast to its (N, n_clusters) output, so both
            # branches build an (N, 1) column of zeros. A flat (N,) vector would
            # be aligned with the cluster axis instead. It would then only be
            # valid when N == n_clusters or N == 1.
            N = X.type.shape[0]
            if isinstance(N, int):
                zeros = np.zeros((N, 1), dtype=dtype)
            else:
                # Batch size unknown at conversion time: derive the zero column
                # from rs at runtime. The multiplier uses the input dtype.
                # A hard-coded float32 would clash with double inputs.
                zeros = OnnxMul(rs,
                                np.array([0], dtype=dtype),
                                op_version=opv)

            # -2 * X @ C.T + ||x||^2
            z = OnnxAdd(rs,
                        OnnxGemm(X,
                                 C,
                                 zeros,
                                 alpha=-2.,
                                 transB=1,
                                 op_version=opv),
                        op_version=opv)
            # + ||c||^2 -> squared distances, shape (N, n_clusters)
            y2 = OnnxAdd(C2, z, op_version=opv)
            lo = OnnxArgMin(y2,
                            axis=1,
                            keepdims=0,
                            output_names=out[:1],
                            op_version=opv)
            # NOTE(review): rounding can make y2 slightly negative, which Sqrt
            # turns into NaN; consider clipping at 0 first.
            y2s = OnnxSqrt(y2,
                           output_names=out[1:],
                           op_version=opv)

            lo.add_to(scope, container)
            y2s.add_to(scope, container)