Ejemplo n.º 1
0
    def create_custom_TopK(self):
        """Swap the TF ``TopKV2`` converter for a custom-layer TopK around a yield.

        Before the ``yield``: registers a custom MIL op ``custom_topk`` bound to
        a Core ML custom layer (class name ``"TopK"``, parameters ``k``,
        ``axis`` and ``sorted``). It then overrides the TF ``TopKV2`` converter
        so that it lowers to that op.

        After the ``yield``: the previously registered ``TopKV2`` converter is
        put back. The restore runs in a ``finally`` clause, so the override
        does not leak into other tests in either of these cases:

        * the generator is closed early;
        * an exception is thrown into it at the ``yield`` (e.g. when it is
          driven as a context manager).
        """

        # Defining SSA TopK Op
        @register_op(doc_str="Custom TopK Layer", is_custom_op=True)
        class custom_topk(Operation):
            input_spec = InputSpec(
                x=TensorInputType(),
                k=IntInputType(const=True, optional=True),
                axis=IntInputType(const=True, optional=True),
                sorted=BoolInputType(const=True, optional=True),
            )

            # Custom-layer metadata: the Core ML layer class name, which op
            # inputs are wired as layer inputs, and which are stored as layer
            # parameters.
            bindings = {
                "class_name": "TopK",
                "input_order": ["x"],
                "parameters": ["k", "axis", "sorted"],
                "description": "Top K Custom layer",
            }

            def default_inputs(self):
                # Defaults for the optional inputs k, axis and sorted.
                return DefaultInputs(
                    k=1,
                    axis=-1,
                    sorted=False,
                )

            def __init__(self, **kwargs):
                super(custom_topk, self).__init__(**kwargs)

            def type_inference(self):
                """Return the (values, indices) tensor types; indices are int32.

                Both outputs have the input's shape with dimension ``axis``
                replaced by ``k``.

                Raises:
                    ValueError: if the size of ``axis`` is known and smaller
                        than ``k``.
                """
                x_type = self.x.dtype
                x_shape = self.x.shape
                k = self.k.val
                axis = self.axis.val

                # A symbolic axis size cannot be checked against k here.
                if not is_symbolic(x_shape[axis]) and k > x_shape[axis]:
                    msg = "K={} is greater than size of the given axis={}"
                    raise ValueError(msg.format(k, axis))

                ret_shape = list(x_shape)
                ret_shape[axis] = k
                return types.tensor(x_type, ret_shape), types.tensor(types.int32, ret_shape)

        # TODO: rdar://61241807 ([MIL] [Polish] Custom layer operator documentation)
        # Keep the stock TopKV2 converter so that the default conversion path
        # can be restored (and tested again) once this override is torn down.
        default_tf_topk = _TF_OPS_REGISTRY.get("TopKV2", None)

        # Override TopK op with override=True flag
        @register_tf_op(tf_alias=["TopKV2"], override=True)
        def CustomTopK(context, node):
            x = context[node.inputs[0]]
            k = context[node.inputs[1]]
            # Named is_sorted so that it does not shadow the builtin sorted().
            is_sorted = node.attr.get("sorted", False)
            topk = mb.custom_topk(x=x, k=k.val, axis=-1, sorted=is_sorted, name=node.name)
            context.add(node.name, topk)

        try:
            yield
        finally:
            if default_tf_topk is None:
                # Nothing was registered before the override: remove the key
                # rather than leaving a None entry behind.
                _TF_OPS_REGISTRY.pop("TopKV2", None)
            else:
                _TF_OPS_REGISTRY["TopKV2"] = default_tf_topk
Ejemplo n.º 2
0
    def create_custom_selu(self):
        """Temporarily override the TF ``Selu`` converter around a yield.

        The override lowers ``Selu`` as ``scale * elu(x, alpha)`` using the MIL
        ``elu`` and ``mul`` ops. After the ``yield``, the previously registered
        ``Selu`` converter is restored.

        The restore runs in a ``finally`` clause, so the override cannot leak
        into other tests if the generator is closed early or an exception is
        thrown into it at the ``yield``.
        """
        default_selu = _TF_OPS_REGISTRY.get("Selu", None)

        # No aliases are given, so the converter is presumably registered
        # under the function's own name, "Selu" (the key saved and restored
        # here).
        @register_tf_op(tf_alias=[], override=True)
        def Selu(context, node):
            x = context[node.inputs[0]]
            # SELU alpha and scale constants, as float32-rounded values.
            alpha = 1.6732631921768188
            scale = 1.0507010221481323
            out_elu = mb.elu(x=x, alpha=alpha)
            out = mb.mul(x=out_elu, y=scale, name=node.name)
            context.add(node.name, out)

        try:
            yield
        finally:
            if default_selu is None:
                # Nothing was registered before the override: remove the key
                # rather than leaving a None entry behind.
                _TF_OPS_REGISTRY.pop("Selu", None)
            else:
                _TF_OPS_REGISTRY["Selu"] = default_selu
Ejemplo n.º 3
0
                    ), "Custom Layer class name mis-match"
            assert (transpose_a == layers[-1].custom.parameters["transpose_x"].
                    boolValue), "Incorrect parameter value k"
            assert (transpose_b == layers[-1].custom.parameters["transpose_y"].
                    boolValue), "Incorrect parameter value k"
            assert (a_is_sparse == layers[-1].custom.parameters["x_is_sparse"].
                    boolValue), "Incorrect parameter value k"
            assert (b_is_sparse == layers[-1].custom.parameters["y_is_sparse"].
                    boolValue), "Incorrect parameter value k"


# TODO: rdar://61241807 ([MIL] [Polish] Custom layer operator documentation)
# Remember the stock TF TopKV2 converter before the override below replaces
# it, so that the default conversion path can still be tested.
# Holds None when no TopKV2 converter has been registered yet.
default_tf_topk = _TF_OPS_REGISTRY.get("TopKV2")


# Override the TF TopKV2 converter; override=True replaces the existing entry.
@register_tf_op(tf_alias=["TopKV2"], override=True)
def CustomTopK(context, node):
    """Convert a TF ``TopKV2`` node into the custom MIL ``custom_topk`` op.

    ``node.inputs[0]`` is the input tensor and ``node.inputs[1]`` is ``k``.
    Because ``k.val`` is read, ``k`` is presumably expected to be a constant.
    The op always runs along the last axis (``axis=-1``). The node's
    ``sorted`` attribute defaults to False when absent. The result is
    registered in ``context`` under the node's name.
    """
    x = context[node.inputs[0]]
    k = context[node.inputs[1]]
    # Named is_sorted so that it does not shadow the builtin sorted().
    is_sorted = node.attr.get("sorted", False)
    topk = mb.custom_topk(x=x, k=k.val, axis=-1, sorted=is_sorted, name=node.name)
    context.add(node.name, topk)


# Custom TF TopK: the converter registered for "TopKV2" after the override,
# presumably the CustomTopK function. Raises KeyError if no "TopKV2" entry
# exists.
custom_tf_topk = _TF_OPS_REGISTRY["TopKV2"]