コード例 #1
0
    def test_plain_input_output(self):
        """Check OpDesc creation for an op with two inputs and one output.

        An op proto of type "test" is declared with inputs X and Y and
        output Z (all comment fields set to a placeholder string). The proto
        is passed to ``op.OpDescCreationMethod`` and the OpDesc produced for
        ``X="a", Y="b", Z="c"`` is compared with a hand-built one.
        """
        placeholder = "not matter"

        # Declare the op: inputs X and Y, output Z.
        proto = framework_pb2.OpProto()
        proto.type = "test"
        for input_name in ("X", "Y"):
            declared = proto.inputs.add()
            declared.name = input_name
            declared.comment = placeholder
        declared = proto.outputs.add()
        declared.name = "Z"
        declared.comment = placeholder
        proto.comment = placeholder

        self.assertTrue(proto.IsInitialized())

        create_desc = op.OpDescCreationMethod(proto)
        actual = create_desc(X="a", Y="b", Z="c")

        # Hand-build the OpDesc the call above is expected to produce.
        wanted = framework_pb2.OpDesc()
        wanted.type = "test"
        for parameter, argument in (("X", "a"), ("Y", "b")):
            slot = wanted.inputs.add()
            slot.parameter = parameter
            slot.arguments.extend([argument])
        slot = wanted.outputs.add()
        slot.parameter = "Z"
        slot.arguments.extend(["c"])

        self.assertEqual(wanted, actual)
コード例 #2
0
ファイル: writer.py プロジェクト: zhangwernui/X2Paddle
    def OpDesc(self, op_type, input_key_vals, output_key_vals, attrs):
        """
        Create an OpDesc, append it to ``self.op_descs`` and return it.

        :param op_type: value stored in the ``type`` field of the new OpDesc
        :param input_key_vals: unpacked as positional arguments to
            ``self.OpDescVars``; the result fills the ``inputs`` field
        :param output_key_vals: unpacked as positional arguments to
            ``self.OpDescVars``; the result fills the ``outputs`` field
        :param attrs: passed to ``self.OpDescAttrs``; the result fills the
            ``attrs`` field
        :return: the newly created ``framework_pb2.OpDesc``
        """

        op_desc = framework_pb2.OpDesc()
        op_desc.type = op_type

        input_vars = self.OpDescVars(*input_key_vals)
        op_desc.inputs.extend(input_vars)

        output_vars = self.OpDescVars(*output_key_vals)
        op_desc.outputs.extend(output_vars)

        attr_descs = self.OpDescAttrs(attrs)
        op_desc.attrs.extend(attr_descs)

        self.op_descs.append(op_desc)
        return op_desc
コード例 #3
0
ファイル: writer.py プロジェクト: shanyi15/X2Paddle
    def OpDesc(self,
               name,
               input_val_keys=None,
               output_val_keys=None,
               attrs=None):
        """
        Create an OpDesc, append it to ``self.op_descs`` and return it.

        Any of the optional arguments left as ``None`` leaves the
        corresponding field of the OpDesc empty.

        :param name: value stored in the ``type`` field of the new OpDesc
        :param input_val_keys: if given, unpacked as positional arguments to
            ``self.OpDescVars``; the result fills the ``inputs`` field
        :param output_val_keys: if given, unpacked as positional arguments to
            ``self.OpDescVars``; the result fills the ``outputs`` field
        :param attrs: if given, passed to ``self.OpDescAttrs``; the result
            fills the ``attrs`` field
        :return: the newly created ``framework_pb2.OpDesc``
        """

        op_desc = framework_pb2.OpDesc()
        op_desc.type = name

        if input_val_keys is not None:
            input_vars = self.OpDescVars(*input_val_keys)
            op_desc.inputs.extend(input_vars)

        if output_val_keys is not None:
            output_vars = self.OpDescVars(*output_val_keys)
            op_desc.outputs.extend(output_vars)

        if attrs is not None:
            attr_descs = self.OpDescAttrs(attrs)
            op_desc.attrs.extend(attr_descs)

        self.op_descs.append(op_desc)
        return op_desc
コード例 #4
0
    def __call__(self, *args, **kwargs):
        """
        Convert user's input to OpDesc. Only keyword arguments are supported.

        Every input, output and attribute declared in ``self.__op_proto__``
        is looked up in ``kwargs`` by its name:

        * inputs/outputs: a value for which ``is_str`` is true is wrapped in
          a one-element list; a missing name yields an empty argument list.
        * attributes: those flagged ``generated``, and those that are missing
          from ``kwargs`` or ``None``, are skipped. ``np.ndarray`` values are
          converted with ``tolist()`` before being stored.

        :return: The OpDesc based on user input.
        :rtype: framework_pb2.OpDesc
        :raises ValueError: if positional arguments are passed, or if more
            than one argument is given for a non-duplicable input/output.
        :raises NotImplementedError: if an attribute type is not handled.
        """
        if args:
            raise ValueError("Only keyword arguments are supported.")

        proto = self.__op_proto__
        op_desc = framework_pb2.OpDesc()

        # Inputs and outputs are filled the same way; only the source list,
        # the destination list and the error message differ.
        sections = (
            (proto.inputs, op_desc.inputs,
             "Input %s expects only one input, but %d are given."),
            (proto.outputs, op_desc.outputs,
             "Output %s expects only one output, but %d are given."),
        )
        for declared_vars, desc_vars, too_many_message in sections:
            for declared in declared_vars:
                arguments = kwargs.get(declared.name, [])
                if is_str(arguments):
                    arguments = [arguments]

                if not declared.duplicable and len(arguments) > 1:
                    raise ValueError(
                        too_many_message % (declared.name, len(arguments)))

                var = desc_vars.add()
                var.parameter = declared.name
                var.arguments.extend(arguments)

        # Types
        op_desc.type = proto.type

        # Attrs
        for attr in proto.attrs:
            if attr.generated:
                continue
            value = kwargs.get(attr.name)
            if value is None:
                continue

            new_attr = op_desc.attrs.add()
            new_attr.name = attr.name
            new_attr.type = attr.type
            if isinstance(value, np.ndarray):
                value = value.tolist()

            kind = attr.type
            if kind == framework_pb2.INT:
                new_attr.i = value
            elif kind == framework_pb2.FLOAT:
                new_attr.f = value
            elif kind == framework_pb2.LONG:
                new_attr.l = value
            elif kind == framework_pb2.STRING:
                new_attr.s = value
            elif kind == framework_pb2.BOOLEAN:
                new_attr.b = value
            elif kind == framework_pb2.INTS:
                new_attr.ints.extend(value)
            elif kind == framework_pb2.FLOATS:
                new_attr.floats.extend(value)
            elif kind == framework_pb2.STRINGS:
                new_attr.strings.extend(value)
            elif kind == framework_pb2.BOOLEANS:
                new_attr.bools.extend(value)
            elif kind == framework_pb2.LONGS:
                new_attr.longs.extend(value)
            elif kind == framework_pb2.INT_PAIRS:
                # Each item is indexed as (first, second).
                for item in value:
                    pair = new_attr.int_pairs.add()
                    pair.first = item[0]
                    pair.second = item[1]
            else:
                raise NotImplementedError(
                    "A not supported attribute type: %s." % (str(kind)))

        return op_desc
コード例 #5
0
    def test_multiple_input_plain_output(self):
        """Check OpDesc creation for an "fc" op with duplicable inputs.

        The op proto declares inputs X and W (both duplicable), input b and
        output Y. The creation method is exercised twice: first with a single
        name for every slot, then with three names each for X and W. Each
        result is compared with a hand-built OpDesc.
        """
        proto = framework_pb2.OpProto()
        proto.type = "fc"
        for input_name, is_duplicable in (("X", True), ("W", True),
                                          ("b", False)):
            declared = proto.inputs.add()
            declared.name = input_name
            declared.comment = ""
            if is_duplicable:
                # Leave the flag unset (rather than explicitly False) for
                # inputs that are not duplicable.
                declared.duplicable = True

        declared = proto.outputs.add()
        declared.name = "Y"
        declared.comment = ""

        proto.comment = ""
        self.assertTrue(proto.IsInitialized())
        create_desc = op.OpDescCreationMethod(proto)

        def expected_desc(x_arguments, w_arguments):
            # Hand-build the "fc" OpDesc for the given X and W argument
            # lists; b and Y are the same in both scenarios.
            desc = framework_pb2.OpDesc()
            for parameter, arguments in (("X", x_arguments),
                                         ('W', w_arguments),
                                         ('b', ['b'])):
                slot = desc.inputs.add()
                slot.parameter = parameter
                slot.arguments.extend(arguments)
            slot = desc.outputs.add()
            slot.parameter = 'Y'
            slot.arguments.extend(['y'])
            desc.type = 'fc'
            return desc

        # One argument per slot.
        single = create_desc(X="x", W="w", b="b", Y="y")
        self.assertEqual(expected_desc(['x'], ['w']), single)

        # Several arguments for the duplicable slots.
        multiple = create_desc(X=['x1', 'x2', 'x3'],
                               b='b',
                               W=['w1', 'w2', 'w3'],
                               Y='y')
        self.assertEqual(
            expected_desc(['x1', 'x2', 'x3'], ['w1', 'w2', 'w3']), multiple)
コード例 #6
0
    def test_attrs(self):
        """Check that user-supplied attributes end up in the created OpDesc.

        An op proto of type "test" with one input X and six attributes
        (INT, FLOAT, STRING, INTS, FLOATS, STRINGS) is declared. The OpDesc
        created from keyword arguments carrying a value for each attribute is
        compared with a hand-built one.
        """
        # Each case: (attribute name, attribute type,
        #             name of the OpDesc attr field holding the value, value)
        scalar_cases = (
            ("int_attr", framework_pb2.INT, "i", 10),
            ("float_attr", framework_pb2.FLOAT, "f", 3.2),
            ("string_attr", framework_pb2.STRING, "s", "test_str"),
        )
        repeated_cases = (
            ("ints_attr", framework_pb2.INTS, "ints", [0, 1, 2, 3, 4]),
            ("floats_attr", framework_pb2.FLOATS, "floats", [0.2, 3.2, 4.5]),
            ("strings_attr", framework_pb2.STRINGS, "strings",
             ["a", "b", "c"]),
        )
        all_cases = scalar_cases + repeated_cases

        # Declare the op: input X plus one attribute per case, in order.
        proto = framework_pb2.OpProto()
        proto.type = "test"
        declared_input = proto.inputs.add()
        declared_input.name = 'X'
        declared_input.comment = ""
        for attr_name, attr_type, _, _ in all_cases:
            declared_attr = proto.attrs.add()
            declared_attr.name = attr_name
            declared_attr.comment = ""
            declared_attr.type = attr_type

        proto.comment = ""
        self.assertTrue(proto.IsInitialized())

        create_desc = op.OpDescCreationMethod(proto)

        # Hand-build the expected OpDesc before the call so the case values
        # are copied into it first.
        wanted = framework_pb2.OpDesc()
        wanted.type = "test"

        slot = wanted.inputs.add()
        slot.parameter = "X"
        slot.arguments.extend(['a'])

        for attr_name, attr_type, field, value in scalar_cases:
            attr = wanted.attrs.add()
            attr.name = attr_name
            attr.type = attr_type
            setattr(attr, field, value)

        for attr_name, attr_type, field, values in repeated_cases:
            attr = wanted.attrs.add()
            attr.name = attr_name
            attr.type = attr_type
            getattr(attr, field).extend(values)

        user_attrs = {attr_name: value for attr_name, _, _, value in all_cases}
        actual = create_desc(X="a", **user_attrs)

        self.assertEqual(wanted, actual)