Beispiel #1
0
    def Attr(self, attr_name, attr_value, attr_type_name=None):
        r"""Store one attribute value on the op that is being built.

        The attribute type is obtained from ``oneflow_api.GetUserOpAttrType``
        using the op type name and ``attr_name``. ``attr_value`` is checked
        with assertions against that type, written into a new ``AttrValue``
        and passed to ``self._builder.attr``.

        Args:
            attr_name (str): attribute name of op
            attr_value (Any): attribute value of op; the Python type that is
                accepted depends on the attribute type reported for
                ``attr_name``
            attr_type_name (Any, optional): deprecated. Any value other than
                None only prints a deprecation notice and one stack entry;
                the value itself is not used.

        Raises:
            AssertionError: raised when ``attr_name`` is not a str, the op
                type name is None, or ``attr_value`` does not match the
                attribute type.
            ValueError: raised when the attribute type is not one of the
                handled kinds.

        Returns:
            The builder itself (``self``), so that calls can be chained.
        """

        def _checked(value, expected):
            # Assert-based type check that hands the value back for inline use.
            assert isinstance(value, expected)
            return value

        if attr_type_name is not None:
            deprecation_notice = """WARNING: Argument 'attr_type_name' of UserOpConfBuilder.Attr has been deprecated. Please remove it.

            For instance:
                -     .Attr("out_num", out_num, "AttrTypeInt64")
                +     .Attr("out_num", out_num)
                        """
            print(deprecation_notice)
            # The second-to-last stack entry is the frame that called Attr.
            print(traceback.format_stack()[-2])

        cfg = user_op_attr_cfg
        sequence = (tuple, list)
        real = (float, int)

        proto = cfg.AttrValue()
        _checked(attr_name, str)
        assert self._op_type_name is not None
        kind = oneflow_api.GetUserOpAttrType(self._op_type_name, attr_name)

        if kind == cfg.kAtInt32:
            proto.set_at_int32(_checked(attr_value, int))
        elif kind == cfg.kAtInt64:
            proto.set_at_int64(_checked(attr_value, int))
        elif kind == cfg.kAtBool:
            proto.set_at_bool(_checked(attr_value, bool))
        elif kind == cfg.kAtFloat:
            proto.set_at_float(_checked(attr_value, real))
        elif kind == cfg.kAtDouble:
            proto.set_at_double(_checked(attr_value, real))
        elif kind == cfg.kAtString:
            proto.set_at_string(_checked(attr_value, str))
        elif kind == cfg.kAtShape:
            dims = _checked(attr_value, sequence)
            shape_proto = proto.mutable_at_shape()
            for dim in dims:
                shape_proto.add_dim(_checked(dim, int))
        elif kind == cfg.kAtDataType:
            assert attr_value in oneflow.dtypes()
            dtype_id = _checked(
                oneflow_api.deprecated.GetProtoDtype4OfDtype(attr_value), int
            )
            proto.set_at_data_type(data_type_cfg.DataType(dtype_id))
        elif kind == cfg.kAtListInt32:
            items = _checked(attr_value, sequence)
            int32_list = proto.mutable_at_list_int32()
            for item in items:
                int32_list.add_val(_checked(item, int))
        elif kind == cfg.kAtListInt64:
            items = _checked(attr_value, sequence)
            int64_list = proto.mutable_at_list_int64()
            for item in items:
                int64_list.add_val(_checked(item, int))
        elif kind == cfg.kAtListFloat:
            items = _checked(attr_value, sequence)
            float_list = proto.mutable_at_list_float()
            for item in items:
                float_list.add_val(_checked(item, real))
        elif kind == cfg.kAtListDataType:
            items = _checked(attr_value, sequence)
            dtype_list = proto.mutable_at_list_data_type()
            for item in items:
                assert item in oneflow.dtypes()
                dtype_id = _checked(
                    oneflow_api.deprecated.GetProtoDtype4OfDtype(item), int
                )
                dtype_list.add_val(data_type_cfg.DataType(dtype_id))
        elif kind == cfg.kAtListShape:
            shapes = _checked(attr_value, sequence)
            shape_list = proto.mutable_at_list_shape().mutable_val()
            for dims in shapes:
                _checked(dims, sequence)
                # Each shape is assembled separately and then copied into a
                # newly added list slot.
                shape = shape_cfg.ShapeProto()
                for dim in dims:
                    shape.add_dim(_checked(dim, int))
                shape_list.Add().CopyFrom(shape)
        elif kind == cfg.kAtListString:
            items = _checked(attr_value, sequence)
            string_list = proto.mutable_at_list_string()
            for item in items:
                string_list.add_val(_checked(item, str))
        else:
            raise ValueError("Invalid op attribute type {}".format(kind))

        self._builder.attr(attr_name, proto)
        return self
Beispiel #2
0
    def Attr(self, attr_name, attr_value, attr_type_name=None):
        r"""Set value of op's attribute.

        The attribute type is looked up with ``oneflow_api.GetUserOpAttrType``
        from the op type name stored in the op conf and ``attr_name``. The
        value is checked with assertions against that type, written into a
        new ``AttrValue`` message and copied into the op conf's ``attr`` map
        under ``attr_name``.

        Args:
            attr_name (str): attribute name of op
            attr_value (Any): attribute value of op; the accepted Python type
                depends on the attribute type reported for ``attr_name``.
                Float and double attributes (and float lists) accept ints as
                well as floats.
            attr_type_name (Any, optional): deprecated. Any value other than
                None only prints a deprecation notice and one stack entry;
                the value itself is not used.

        Raises:
            AssertionError: raised when ``attr_name`` is not a str or
                ``attr_value`` does not match the op's attribute type.
            ValueError: raised when the attribute type is not one of the
                handled kinds.

        Returns:
            The builder itself (``self``), so that calls can be chained.
        """
        if attr_type_name is not None:
            print(
                """WARNING: Argument 'attr_type_name' of UserOpConfBuilder.Attr has been deprecated. Please remove it.
            
            For instance:
                -     .Attr("out_num", out_num, "AttrTypeInt64")
                +     .Attr("out_num", out_num)
                        """)
            # The second-to-last stack entry is the frame that called Attr.
            print(traceback.format_stack()[-2])

        attribute = attr_value_pb.AttrValue()
        assert isinstance(attr_name, str)
        attr_type = oneflow_api.GetUserOpAttrType(
            self.user_op_.op_conf_.user_conf.op_type_name, attr_name)
        if attr_type == attr_value_pb.kAtInt32:
            assert isinstance(attr_value, int)
            attribute.at_int32 = attr_value
        elif attr_type == attr_value_pb.kAtInt64:
            assert isinstance(attr_value, int)
            attribute.at_int64 = attr_value
        elif attr_type == attr_value_pb.kAtBool:
            assert isinstance(attr_value, bool)
            attribute.at_bool = attr_value
        elif attr_type == attr_value_pb.kAtFloat:
            # Ints are valid for floating-point attributes, e.g. Attr("x", 1).
            assert isinstance(attr_value, (float, int))
            attribute.at_float = attr_value
        elif attr_type == attr_value_pb.kAtDouble:
            assert isinstance(attr_value, (float, int))
            attribute.at_double = attr_value
        elif attr_type == attr_value_pb.kAtString:
            assert isinstance(attr_value, str)
            attribute.at_string = attr_value
        elif attr_type == attr_value_pb.kAtShape:
            assert isinstance(attr_value, (tuple, list))
            assert all(isinstance(x, int) for x in attr_value)
            attribute.at_shape.dim[:] = list(attr_value)
        elif attr_type == attr_value_pb.kAtDataType:
            assert (isinstance(attr_value.oneflow_proto_dtype, int)
                    and attr_value in oneflow.dtypes())
            attribute.at_data_type = attr_value.oneflow_proto_dtype
        elif attr_type == attr_value_pb.kAtListInt32:
            assert isinstance(attr_value, (tuple, list))
            assert all(isinstance(x, int) for x in attr_value)
            attribute.at_list_int32.val[:] = list(attr_value)
        elif attr_type == attr_value_pb.kAtListInt64:
            assert isinstance(attr_value, (tuple, list))
            assert all(isinstance(x, int) for x in attr_value)
            attribute.at_list_int64.val[:] = list(attr_value)
        elif attr_type == attr_value_pb.kAtListFloat:
            assert isinstance(attr_value, (tuple, list))
            assert all(isinstance(x, (float, int)) for x in attr_value)
            attribute.at_list_float.val[:] = list(attr_value)
        elif attr_type == attr_value_pb.kAtListDataType:
            assert isinstance(attr_value, (tuple, list))
            assert all(
                isinstance(x.oneflow_proto_dtype, int)
                and x in oneflow.dtypes() for x in attr_value)
            attribute.at_list_data_type.val[:] = [
                x.oneflow_proto_dtype for x in attr_value]
        elif attr_type == attr_value_pb.kAtListShape:
            assert isinstance(attr_value, (tuple, list))
            assert all(isinstance(x, (tuple, list)) for x in attr_value)
            for dims in attr_value:
                shape = shape_util.ShapeProto()
                shape.dim[:] = list(dims)
                attribute.at_list_shape.val.append(shape)
        elif attr_type == attr_value_pb.kAtListString:
            assert isinstance(attr_value, (tuple, list))
            assert all(isinstance(x, str) for x in attr_value)
            attribute.at_list_string.val[:] = list(attr_value)
        else:
            raise ValueError("Invalid op attribute type {}".format(attr_type))

        self.user_op_.op_conf_.user_conf.attr[attr_name].CopyFrom(attribute)
        return self
Beispiel #3
0
def convert_to_user_attr_value(op_type_name, attr_name, attr_value):
    """Build a ``user_op_attr_cfg.AttrValue`` that holds ``attr_value``.

    The attribute type is obtained from
    ``oneflow_api.GetUserOpAttrType(op_type_name, attr_name)``. It decides
    which field of the new ``AttrValue`` is filled and which Python type
    ``attr_value`` is asserted to have.

    Args:
        op_type_name: op type name, passed through to ``GetUserOpAttrType``.
        attr_name (str): attribute name of the op.
        attr_value (Any): value to store; the accepted Python type depends
            on the attribute type reported for ``attr_name``.

    Raises:
        AssertionError: raised when ``attr_name`` is not a str or
            ``attr_value`` does not match the attribute type.
        ValueError: raised when the attribute type is not one of the
            handled kinds.

    Returns:
        The populated ``AttrValue`` object.
    """
    cfg = user_op_attr_cfg
    sequence_types = (tuple, list)
    number_types = (float, int)

    result = cfg.AttrValue()
    assert isinstance(attr_name, str)
    kind = oneflow_api.GetUserOpAttrType(op_type_name, attr_name)

    # Scalar attribute kinds.
    if kind == cfg.kAtInt32:
        assert isinstance(attr_value, int)
        result.set_at_int32(attr_value)
        return result
    if kind == cfg.kAtInt64:
        assert isinstance(attr_value, int)
        result.set_at_int64(attr_value)
        return result
    if kind == cfg.kAtBool:
        assert isinstance(attr_value, bool)
        result.set_at_bool(attr_value)
        return result
    if kind == cfg.kAtFloat:
        assert isinstance(attr_value, number_types)
        result.set_at_float(attr_value)
        return result
    if kind == cfg.kAtDouble:
        assert isinstance(attr_value, number_types)
        result.set_at_double(attr_value)
        return result
    if kind == cfg.kAtString:
        assert isinstance(attr_value, str)
        result.set_at_string(attr_value)
        return result
    if kind == cfg.kAtShape:
        assert isinstance(attr_value, sequence_types)
        append_dim = result.mutable_at_shape().add_dim
        for dim in attr_value:
            assert isinstance(dim, int)
            append_dim(dim)
        return result
    if kind == cfg.kAtDataType:
        assert attr_value in oneflow.dtypes()
        dtype_id = oneflow_api.deprecated.GetProtoDtype4OfDtype(attr_value)
        assert isinstance(dtype_id, int)
        result.set_at_data_type(data_type_cfg.DataType(dtype_id))
        return result

    # List attribute kinds: every element is checked right before it is added.
    if kind == cfg.kAtListInt32:
        assert isinstance(attr_value, sequence_types)
        append_val = result.mutable_at_list_int32().add_val
        for item in attr_value:
            assert isinstance(item, int)
            append_val(item)
        return result
    if kind == cfg.kAtListInt64:
        assert isinstance(attr_value, sequence_types)
        append_val = result.mutable_at_list_int64().add_val
        for item in attr_value:
            assert isinstance(item, int)
            append_val(item)
        return result
    if kind == cfg.kAtListFloat:
        assert isinstance(attr_value, sequence_types)
        append_val = result.mutable_at_list_float().add_val
        for item in attr_value:
            assert isinstance(item, number_types)
            append_val(item)
        return result
    if kind == cfg.kAtListDataType:
        assert isinstance(attr_value, sequence_types)
        append_val = result.mutable_at_list_data_type().add_val
        for item in attr_value:
            assert item in oneflow.dtypes()
            dtype_id = oneflow_api.deprecated.GetProtoDtype4OfDtype(item)
            assert isinstance(dtype_id, int)
            append_val(data_type_cfg.DataType(dtype_id))
        return result
    if kind == cfg.kAtListShape:
        assert isinstance(attr_value, sequence_types)
        shape_list = result.mutable_at_list_shape().mutable_val()
        for dims in attr_value:
            assert isinstance(dims, sequence_types)
            # Each shape is assembled separately and then copied into a
            # newly added list slot.
            shape = shape_cfg.ShapeProto()
            for dim in dims:
                assert isinstance(dim, int)
                shape.add_dim(dim)
            shape_list.Add().CopyFrom(shape)
        return result
    if kind == cfg.kAtListString:
        assert isinstance(attr_value, sequence_types)
        append_val = result.mutable_at_list_string().add_val
        for item in attr_value:
            assert isinstance(item, str)
            append_val(item)
        return result

    raise ValueError("Invalid op attribute type {}".format(kind))