def Attr(self, attr_name, attr_value, attr_type_name=None):
    r"""Set the value of one attribute of the user op being built.

    The accepted Python type of ``attr_value`` is derived from the attribute
    type registered for ``self._op_type_name`` (queried with
    ``oneflow_api.GetUserOpAttrType``); the value is packed into a
    ``user_op_attr_cfg.AttrValue`` and handed to ``self._builder.attr``.

    Args:
        attr_name (str): attribute name of op
        attr_value (Any): attribute value of op. Depending on the attribute
            type this must be an ``int``, ``bool``, ``float``/``int``,
            ``str``, a member of ``oneflow.dtypes()``, or a ``tuple``/``list``
            of such values (shapes are ``tuple``/``list`` of ``int``).
        attr_type_name (str, optional): deprecated; only triggers a printed
            warning and is otherwise ignored.

    Raises:
        AssertionError: if ``attr_name`` is not a ``str``, the op type name
            has not been set, or ``attr_value`` does not match the op's
            attribute type.
        ValueError: if the op's attribute type is not supported here.

    Returns:
        The builder itself (``self``), so calls can be chained.
    """
    if attr_type_name is not None:
        print(
            """WARNING: Argument 'attr_type_name' of UserOpConfBuilder.Attr has been deprecated. Please remove it. For instance: - .Attr("out_num", out_num, "AttrTypeInt64") + .Attr("out_num", out_num) """
        )
        # Show the caller's frame so the outdated call site is easy to find.
        print(traceback.format_stack()[-2])
    attribute = user_op_attr_cfg.AttrValue()
    assert isinstance(attr_name, str)
    assert self._op_type_name is not None
    attr_type = oneflow_api.GetUserOpAttrType(self._op_type_name, attr_name)

    def _mismatch(value, types):
        # Assertion message; only evaluated when the isinstance check fails.
        if not isinstance(types, tuple):
            types = (types,)
        return "attribute '{}' of op '{}' expects {}, but got {}".format(
            attr_name,
            self._op_type_name,
            " or ".join(t.__name__ for t in types),
            type(value).__name__,
        )

    def _expect(value, types):
        assert isinstance(value, types), _mismatch(value, types)

    def _to_data_type(dtype):
        # oneflow dtype object -> proto enum int -> cfg DataType.
        assert dtype in oneflow.dtypes(), (
            "attribute '{}' of op '{}' got {!r}, which is not in "
            "oneflow.dtypes()".format(attr_name, self._op_type_name, dtype)
        )
        proto_dtype = oneflow_api.deprecated.GetProtoDtype4OfDtype(dtype)
        assert isinstance(proto_dtype, int)
        return data_type_cfg.DataType(proto_dtype)

    if attr_type == user_op_attr_cfg.kAtInt32:
        _expect(attr_value, int)
        attribute.set_at_int32(attr_value)
    elif attr_type == user_op_attr_cfg.kAtInt64:
        _expect(attr_value, int)
        attribute.set_at_int64(attr_value)
    elif attr_type == user_op_attr_cfg.kAtBool:
        _expect(attr_value, bool)
        attribute.set_at_bool(attr_value)
    elif attr_type == user_op_attr_cfg.kAtFloat:
        _expect(attr_value, (float, int))
        attribute.set_at_float(attr_value)
    elif attr_type == user_op_attr_cfg.kAtDouble:
        _expect(attr_value, (float, int))
        attribute.set_at_double(attr_value)
    elif attr_type == user_op_attr_cfg.kAtString:
        _expect(attr_value, str)
        attribute.set_at_string(attr_value)
    elif attr_type == user_op_attr_cfg.kAtShape:
        _expect(attr_value, (tuple, list))
        shape = attribute.mutable_at_shape()
        for dim in attr_value:
            _expect(dim, int)
            shape.add_dim(dim)
    elif attr_type == user_op_attr_cfg.kAtDataType:
        attribute.set_at_data_type(_to_data_type(attr_value))
    elif attr_type == user_op_attr_cfg.kAtListInt32:
        _expect(attr_value, (tuple, list))
        values = attribute.mutable_at_list_int32()
        for x in attr_value:
            _expect(x, int)
            values.add_val(x)
    elif attr_type == user_op_attr_cfg.kAtListInt64:
        _expect(attr_value, (tuple, list))
        values = attribute.mutable_at_list_int64()
        for x in attr_value:
            _expect(x, int)
            values.add_val(x)
    elif attr_type == user_op_attr_cfg.kAtListFloat:
        _expect(attr_value, (tuple, list))
        values = attribute.mutable_at_list_float()
        for x in attr_value:
            _expect(x, (float, int))
            values.add_val(x)
    elif attr_type == user_op_attr_cfg.kAtListDataType:
        _expect(attr_value, (tuple, list))
        values = attribute.mutable_at_list_data_type()
        for x in attr_value:
            values.add_val(_to_data_type(x))
    elif attr_type == user_op_attr_cfg.kAtListShape:
        _expect(attr_value, (tuple, list))
        shapes = attribute.mutable_at_list_shape().mutable_val()
        for dims in attr_value:
            _expect(dims, (tuple, list))
            shape = shape_cfg.ShapeProto()
            for dim in dims:
                _expect(dim, int)
                shape.add_dim(dim)
            shapes.Add().CopyFrom(shape)
    elif attr_type == user_op_attr_cfg.kAtListString:
        _expect(attr_value, (tuple, list))
        values = attribute.mutable_at_list_string()
        for x in attr_value:
            _expect(x, str)
            values.add_val(x)
    else:
        raise ValueError("Invalid op attribute type {}".format(attr_type))
    self._builder.attr(attr_name, attribute)
    return self
def Attr(self, attr_name, attr_value, attr_type_name=None):
    r"""Set the value of one attribute on the wrapped user op conf.

    The accepted Python type of ``attr_value`` is derived from the attribute
    type registered for ``self.user_op_.op_conf_.user_conf.op_type_name``
    (queried with ``oneflow_api.GetUserOpAttrType``). The value is packed into
    an ``attr_value_pb.AttrValue`` and copied into
    ``self.user_op_.op_conf_.user_conf.attr[attr_name]``.

    Args:
        attr_name (str): attribute name of op
        attr_value (Any): attribute value of op. Depending on the attribute
            type this must be an ``int``, ``bool``, ``float``/``int``,
            ``str``, a member of ``oneflow.dtypes()``, or a ``tuple``/``list``
            of such values (shapes are ``tuple``/``list`` of ``int``).
        attr_type_name (str, optional): deprecated; only triggers a printed
            warning and is otherwise ignored.

    Raises:
        AssertionError: if ``attr_name`` is not a ``str`` or ``attr_value``
            does not match the op's attribute type.
        ValueError: if the op's attribute type is not supported here.

    Returns:
        The builder itself (``self``), so calls can be chained.
    """
    if attr_type_name is not None:
        print(
            """WARNING: Argument 'attr_type_name' of UserOpConfBuilder.Attr has been deprecated. Please remove it. For instance: - .Attr("out_num", out_num, "AttrTypeInt64") + .Attr("out_num", out_num) """
        )
        # Show the caller's frame so the outdated call site is easy to find.
        print(traceback.format_stack()[-2])
    attribute = attr_value_pb.AttrValue()
    assert isinstance(attr_name, str)
    attr_type = oneflow_api.GetUserOpAttrType(
        self.user_op_.op_conf_.user_conf.op_type_name, attr_name
    )
    if attr_type == attr_value_pb.kAtInt32:
        assert isinstance(attr_value, int)
        attribute.at_int32 = attr_value
    elif attr_type == attr_value_pb.kAtInt64:
        assert isinstance(attr_value, int)
        attribute.at_int64 = attr_value
    elif attr_type == attr_value_pb.kAtBool:
        assert isinstance(attr_value, bool)
        attribute.at_bool = attr_value
    elif attr_type == attr_value_pb.kAtFloat:
        # Integral values (e.g. ``1``) are valid floats; protobuf converts
        # them on assignment to float/double fields.
        assert isinstance(attr_value, (float, int))
        attribute.at_float = attr_value
    elif attr_type == attr_value_pb.kAtDouble:
        assert isinstance(attr_value, (float, int))
        attribute.at_double = attr_value
    elif attr_type == attr_value_pb.kAtString:
        assert isinstance(attr_value, str)
        attribute.at_string = attr_value
    elif attr_type == attr_value_pb.kAtShape:
        assert isinstance(attr_value, (tuple, list))
        assert all(isinstance(x, int) for x in attr_value)
        attribute.at_shape.dim[:] = list(attr_value)
    elif attr_type == attr_value_pb.kAtDataType:
        # Check membership first so a non-dtype value fails the validation
        # instead of raising AttributeError on ``.oneflow_proto_dtype``.
        assert attr_value in oneflow.dtypes()
        assert isinstance(attr_value.oneflow_proto_dtype, int)
        attribute.at_data_type = attr_value.oneflow_proto_dtype
    elif attr_type == attr_value_pb.kAtListInt32:
        assert isinstance(attr_value, (tuple, list))
        assert all(isinstance(x, int) for x in attr_value)
        attribute.at_list_int32.val[:] = list(attr_value)
    elif attr_type == attr_value_pb.kAtListInt64:
        assert isinstance(attr_value, (tuple, list))
        assert all(isinstance(x, int) for x in attr_value)
        attribute.at_list_int64.val[:] = list(attr_value)
    elif attr_type == attr_value_pb.kAtListFloat:
        assert isinstance(attr_value, (tuple, list))
        assert all(isinstance(x, (float, int)) for x in attr_value)
        attribute.at_list_float.val[:] = list(attr_value)
    elif attr_type == attr_value_pb.kAtListDataType:
        assert isinstance(attr_value, (tuple, list))
        # Membership is checked before touching ``.oneflow_proto_dtype``.
        assert all(
            x in oneflow.dtypes() and isinstance(x.oneflow_proto_dtype, int)
            for x in attr_value
        )
        attribute.at_list_data_type.val[:] = [
            x.oneflow_proto_dtype for x in attr_value
        ]
    elif attr_type == attr_value_pb.kAtListShape:
        assert isinstance(attr_value, (tuple, list))
        assert all(isinstance(x, (tuple, list)) for x in attr_value)
        for dims in attr_value:
            # Validate every dim as kAtShape does, rather than relying on
            # whatever error protobuf raises for a non-int.
            assert all(isinstance(d, int) for d in dims)
            shape = shape_util.ShapeProto()
            shape.dim[:] = list(dims)
            attribute.at_list_shape.val.append(shape)
    elif attr_type == attr_value_pb.kAtListString:
        assert isinstance(attr_value, (tuple, list))
        assert all(isinstance(x, str) for x in attr_value)
        attribute.at_list_string.val[:] = list(attr_value)
    else:
        raise ValueError("Invalid op attribute type {}".format(attr_type))
    self.user_op_.op_conf_.user_conf.attr[attr_name].CopyFrom(attribute)
    return self
def convert_to_user_attr_value(op_type_name, attr_name, attr_value):
    """Pack a Python value into a ``user_op_attr_cfg.AttrValue``.

    The attribute type registered for ``(op_type_name, attr_name)`` is looked
    up with ``oneflow_api.GetUserOpAttrType``; it decides which field of the
    returned AttrValue is filled and which Python types ``attr_value`` may
    have.

    Args:
        op_type_name: type name of the user op that owns the attribute.
        attr_name (str): name of the attribute.
        attr_value: value to store. Depending on the attribute type: an
            ``int``, ``bool``, ``float``/``int``, ``str``, a member of
            ``oneflow.dtypes()``, or a ``tuple``/``list`` of such values
            (shapes are ``tuple``/``list`` of ``int``).

    Returns:
        The populated ``user_op_attr_cfg.AttrValue``.

    Raises:
        AssertionError: if ``attr_name`` is not a ``str`` or ``attr_value``
            does not fit the attribute type.
        ValueError: if the attribute type is not one handled here.
    """
    cfg = user_op_attr_cfg
    result = cfg.AttrValue()
    assert isinstance(attr_name, str)
    kind = oneflow_api.GetUserOpAttrType(op_type_name, attr_name)

    def as_cfg_dtype(dtype):
        # oneflow dtype object -> proto enum int -> cfg DataType.
        assert dtype in oneflow.dtypes()
        proto_code = oneflow_api.deprecated.GetProtoDtype4OfDtype(dtype)
        assert isinstance(proto_code, int)
        return data_type_cfg.DataType(proto_code)

    def push_dims(target, dims):
        # Append each integer extent of ``dims`` to a shape-like object.
        for extent in dims:
            assert isinstance(extent, int)
            target.add_dim(extent)

    def push_each(get_container, allowed):
        # Check attr_value is a sequence, then validate and append each item.
        assert isinstance(attr_value, (tuple, list))
        container = get_container()
        for item in attr_value:
            assert isinstance(item, allowed)
            container.add_val(item)

    if kind == cfg.kAtInt32:
        assert isinstance(attr_value, int)
        result.set_at_int32(attr_value)
    elif kind == cfg.kAtInt64:
        assert isinstance(attr_value, int)
        result.set_at_int64(attr_value)
    elif kind == cfg.kAtBool:
        assert isinstance(attr_value, bool)
        result.set_at_bool(attr_value)
    elif kind == cfg.kAtFloat:
        assert isinstance(attr_value, (float, int))
        result.set_at_float(attr_value)
    elif kind == cfg.kAtDouble:
        assert isinstance(attr_value, (float, int))
        result.set_at_double(attr_value)
    elif kind == cfg.kAtString:
        assert isinstance(attr_value, str)
        result.set_at_string(attr_value)
    elif kind == cfg.kAtShape:
        assert isinstance(attr_value, (tuple, list))
        push_dims(result.mutable_at_shape(), attr_value)
    elif kind == cfg.kAtDataType:
        result.set_at_data_type(as_cfg_dtype(attr_value))
    elif kind == cfg.kAtListInt32:
        push_each(result.mutable_at_list_int32, int)
    elif kind == cfg.kAtListInt64:
        push_each(result.mutable_at_list_int64, int)
    elif kind == cfg.kAtListFloat:
        push_each(result.mutable_at_list_float, (float, int))
    elif kind == cfg.kAtListDataType:
        assert isinstance(attr_value, (tuple, list))
        dtype_list = result.mutable_at_list_data_type()
        for dtype in attr_value:
            dtype_list.add_val(as_cfg_dtype(dtype))
    elif kind == cfg.kAtListShape:
        assert isinstance(attr_value, (tuple, list))
        shape_list = result.mutable_at_list_shape().mutable_val()
        for dims in attr_value:
            assert isinstance(dims, (tuple, list))
            proto = shape_cfg.ShapeProto()
            push_dims(proto, dims)
            shape_list.Add().CopyFrom(proto)
    elif kind == cfg.kAtListString:
        push_each(result.mutable_at_list_string, str)
    else:
        raise ValueError("Invalid op attribute type {}".format(kind))
    return result