def test_plain_input_output(self):
    """Single-string keyword arguments map onto one-argument input/output slots."""
    proto = framework_pb2.OpProto()
    proto.type = "test"
    for input_name in ("X", "Y"):
        declared = proto.inputs.add()
        declared.name = input_name
        declared.comment = "not matter"
    declared = proto.outputs.add()
    declared.name = "Z"
    declared.comment = "not matter"
    proto.comment = "not matter"
    self.assertTrue(proto.IsInitialized())

    create = op.OpDescCreationMethod(proto)
    actual = create(X="a", Y="b", Z="c")

    # Expected: each input/output parameter carries exactly the one name given.
    want = framework_pb2.OpDesc()
    want.type = "test"
    for parameter, argument in (("X", "a"), ("Y", "b")):
        slot = want.inputs.add()
        slot.parameter = parameter
        slot.arguments.extend([argument])
    slot = want.outputs.add()
    slot.parameter = "Z"
    slot.arguments.extend(["c"])
    self.assertEqual(want, actual)
def OpDesc(self, op_type, input_key_vals=None, output_key_vals=None, attrs=None):
    """
    Add an OpDesc of type ``op_type`` to ``self.op_descs`` and return it.

    :param op_type: value stored in ``desc.type``.
    :param input_key_vals: sequence unpacked as positional arguments to
        ``self.OpDescVars``; the result is appended to ``desc.inputs``.
        None (the default) leaves ``desc.inputs`` empty.
    :param output_key_vals: same as ``input_key_vals`` but for
        ``desc.outputs``. None leaves ``desc.outputs`` empty.
    :param attrs: passed to ``self.OpDescAttrs``; the result is appended to
        ``desc.attrs``. None leaves ``desc.attrs`` empty.
    :return: the newly created OpDesc (also appended to ``self.op_descs``).
    """
    desc = framework_pb2.OpDesc()
    desc.type = op_type
    # Previously `*None` raised TypeError; treat None as "nothing to add",
    # matching the other OpDesc builder that accepts optional arguments.
    if input_key_vals is not None:
        desc.inputs.extend(self.OpDescVars(*input_key_vals))
    if output_key_vals is not None:
        desc.outputs.extend(self.OpDescVars(*output_key_vals))
    if attrs is not None:
        desc.attrs.extend(self.OpDescAttrs(attrs))
    self.op_descs.append(desc)
    return desc
def OpDesc(self, name, input_val_keys=None, output_val_keys=None, attrs=None):
    """
    Add an OpDesc of type ``name`` to ``self.op_descs`` and return it.

    Each optional argument is ignored when it is None. Otherwise
    ``input_val_keys`` / ``output_val_keys`` are unpacked into
    ``self.OpDescVars`` and the results extend ``desc.inputs`` /
    ``desc.outputs``, and ``attrs`` is converted by ``self.OpDescAttrs``
    into ``desc.attrs``.
    """
    desc = framework_pb2.OpDesc()
    desc.type = name
    slots = ((desc.inputs, input_val_keys), (desc.outputs, output_val_keys))
    for container, val_keys in slots:
        if val_keys is None:
            continue
        container.extend(self.OpDescVars(*val_keys))
    if attrs is not None:
        desc.attrs.extend(self.OpDescAttrs(attrs))
    self.op_descs.append(desc)
    return desc
def __call__(self, *args, **kwargs):
    """
    Convert keyword arguments into an OpDesc for ``self.__op_proto__``.

    Only keyword arguments are accepted. Every input and output declared in
    the proto gets an entry; an omitted keyword yields an empty argument
    list, and a value for which ``is_str`` is true is wrapped in a
    one-element list. Attributes not flagged ``generated`` are copied when
    the caller passes a value other than None; numpy arrays are converted
    with ``tolist()`` first.

    :return: The OpDesc built from the user's input.
    :rtype: framework_pb2.OpDesc
    :raises ValueError: if positional arguments are given, or if a
        non-duplicable input/output receives more than one argument.
    :raises NotImplementedError: if an attribute's type is not handled.
    """
    if args:
        raise ValueError("Only keyword arguments are supported.")

    op_desc = framework_pb2.OpDesc()

    def _add_vars(declared_vars, new_entry, too_many_msg):
        # Shared by inputs and outputs: only the target container and the
        # error message differ between the two.
        for declared in declared_vars:
            names = kwargs.get(declared.name, [])
            if is_str(names):
                names = [names]
            if not declared.duplicable and len(names) > 1:
                raise ValueError(too_many_msg % (declared.name, len(names)))
            entry = new_entry()
            entry.parameter = declared.name
            entry.arguments.extend(names)

    def _set_attr_value(target, attr_type, value):
        # Scalar kinds are assigned, list kinds are extended, and pairs are
        # appended as individual messages.
        if attr_type == framework_pb2.INT:
            target.i = value
        elif attr_type == framework_pb2.FLOAT:
            target.f = value
        elif attr_type == framework_pb2.LONG:
            target.l = value
        elif attr_type == framework_pb2.STRING:
            target.s = value
        elif attr_type == framework_pb2.BOOLEAN:
            target.b = value
        elif attr_type == framework_pb2.INTS:
            target.ints.extend(value)
        elif attr_type == framework_pb2.FLOATS:
            target.floats.extend(value)
        elif attr_type == framework_pb2.STRINGS:
            target.strings.extend(value)
        elif attr_type == framework_pb2.BOOLEANS:
            target.bools.extend(value)
        elif attr_type == framework_pb2.LONGS:
            target.longs.extend(value)
        elif attr_type == framework_pb2.INT_PAIRS:
            for pair_values in value:
                pair = target.int_pairs.add()
                pair.first = pair_values[0]
                pair.second = pair_values[1]
        else:
            raise NotImplementedError(
                "A not supported attribute type: %s." % (str(attr_type)))

    _add_vars(self.__op_proto__.inputs, op_desc.inputs.add,
              "Input %s expects only one input, but %d are given.")
    _add_vars(self.__op_proto__.outputs, op_desc.outputs.add,
              "Output %s expects only one output, but %d are given.")

    op_desc.type = self.__op_proto__.type

    for attr in self.__op_proto__.attrs:
        if attr.generated:
            continue
        value = kwargs.get(attr.name, None)
        # Only an explicit None means "not provided"; False/0/"" are kept.
        if value is None:
            continue
        new_attr = op_desc.attrs.add()
        new_attr.name = attr.name
        new_attr.type = attr.type
        if isinstance(value, np.ndarray):
            value = value.tolist()
        _set_attr_value(new_attr, attr.type, value)
    return op_desc
def test_multiple_input_plain_output(self):
    """Duplicable inputs accept either a single name or a list of names."""
    proto = framework_pb2.OpProto()
    proto.type = "fc"
    for input_name, duplicable in (("X", True), ("W", True), ("b", False)):
        declared = proto.inputs.add()
        declared.name = input_name
        declared.comment = ""
        # Only set the flag where it is True; "b" leaves it untouched.
        if duplicable:
            declared.duplicable = True
    declared = proto.outputs.add()
    declared.name = "Y"
    declared.comment = ""
    proto.comment = ""
    self.assertTrue(proto.IsInitialized())

    create = op.OpDescCreationMethod(proto)

    def build_expected(x_args, w_args):
        # Same layout for both scenarios; only X and W arguments vary.
        desc = framework_pb2.OpDesc()
        for parameter, arguments in (("X", x_args), ("W", w_args), ("b", ['b'])):
            slot = desc.inputs.add()
            slot.parameter = parameter
            slot.arguments.extend(arguments)
        slot = desc.outputs.add()
        slot.parameter = 'Y'
        slot.arguments.extend(['y'])
        desc.type = 'fc'
        return desc

    single = create(X="x", W="w", b="b", Y="y")
    self.assertEqual(build_expected(['x'], ['w']), single)

    multiple = create(X=['x1', 'x2', 'x3'], b='b', W=['w1', 'w2', 'w3'], Y='y')
    self.assertEqual(
        build_expected(['x1', 'x2', 'x3'], ['w1', 'w2', 'w3']), multiple)
def test_attrs(self):
    """Scalar and list attribute values are copied into the generated OpDesc."""
    proto = framework_pb2.OpProto()
    proto.type = "test"
    declared = proto.inputs.add()
    declared.name = 'X'
    declared.comment = ""

    attr_specs = [
        ("int_attr", framework_pb2.INT),
        ("float_attr", framework_pb2.FLOAT),
        ("string_attr", framework_pb2.STRING),
        ("ints_attr", framework_pb2.INTS),
        ("floats_attr", framework_pb2.FLOATS),
        ("strings_attr", framework_pb2.STRINGS),
    ]
    for attr_name, attr_type in attr_specs:
        proto_attr = proto.attrs.add()
        proto_attr.name = attr_name
        proto_attr.comment = ""
        proto_attr.type = attr_type
    proto.comment = ""
    self.assertTrue(proto.IsInitialized())

    create = op.OpDescCreationMethod(proto)
    actual = create(
        X="a",
        int_attr=10,
        float_attr=3.2,
        string_attr="test_str",
        ints_attr=[0, 1, 2, 3, 4],
        floats_attr=[0.2, 3.2, 4.5],
        strings_attr=["a", "b", "c"])

    want = framework_pb2.OpDesc()
    want.type = "test"
    slot = want.inputs.add()
    slot.parameter = "X"
    slot.arguments.extend(['a'])

    def expect_attr(attr_name, attr_type):
        # Append a named, typed attr to the expected desc; caller sets value.
        desc_attr = want.attrs.add()
        desc_attr.name = attr_name
        desc_attr.type = attr_type
        return desc_attr

    expect_attr("int_attr", framework_pb2.INT).i = 10
    expect_attr("float_attr", framework_pb2.FLOAT).f = 3.2
    expect_attr("string_attr", framework_pb2.STRING).s = "test_str"
    expect_attr("ints_attr", framework_pb2.INTS).ints.extend([0, 1, 2, 3, 4])
    expect_attr('floats_attr', framework_pb2.FLOATS).floats.extend([0.2, 3.2, 4.5])
    expect_attr('strings_attr', framework_pb2.STRINGS).strings.extend(['a', 'b', 'c'])
    self.assertEqual(want, actual)