Example #1
0
def convert_to_user_attr_value(op_type_name, attr_name, attr_value):
    """Build a ``user_op_attr_cfg.AttrValue`` holding ``attr_value``.

    The attribute's declared type is queried with
    ``oneflow._oneflow_internal.GetUserOpAttrType(op_type_name, attr_name)``;
    ``attr_value`` is checked against that type and then written into the
    matching field of a fresh ``AttrValue``.

    Accepted Python values per attribute type (``bool`` also passes the
    ``int`` checks, since it is an ``int`` subclass):
        kAtInt32 / kAtInt64: int
        kAtBool: bool
        kAtFloat / kAtDouble: float or int
        kAtString: str
        kAtShape: tuple/list of int
        kAtDataType: a member of ``oneflow.dtypes()``
        kAtListInt32 / kAtListInt64: tuple/list of int
        kAtListFloat: tuple/list of float or int
        kAtListDataType: tuple/list of members of ``oneflow.dtypes()``
        kAtListShape: tuple/list of tuple/list of int
        kAtListString: tuple/list of str

    Args:
        op_type_name (str): type name of the user op that owns the attribute.
        attr_name (str): name of the attribute.
        attr_value: value to convert (see the table above).

    Returns:
        user_op_attr_cfg.AttrValue: the populated attribute value.

    Raises:
        AssertionError: if ``attr_name`` is not a str, if ``attr_value`` (or
            one of its elements) has the wrong Python type, or if a dtype is
            not in ``oneflow.dtypes()``. These are explicit ``raise``
            statements rather than ``assert`` so they still run under
            ``python -O``; AssertionError is kept so callers that already
            catch it keep working.
        ValueError: if the attribute type is not one of the types above.
    """

    def _expect(ok, what, value):
        # Explicit raise rather than ``assert`` so the check survives ``-O``.
        if not ok:
            raise AssertionError(
                "attribute {!r} of op {!r}: expected {}, got {!r} of type {}"
                .format(attr_name, op_type_name, what, value,
                        type(value).__name__))

    def _expect_sequence(value):
        _expect(isinstance(value, (tuple, list)), "a tuple or list", value)

    def _append_checked(add, items, types, what):
        # Type-check each element, then hand it to the cfg ``add`` method.
        for item in items:
            _expect(isinstance(item, types), what, item)
            add(item)

    def _to_data_type(dtype):
        # Convert a oneflow dtype into ``data_type_cfg.DataType`` through its
        # integer proto dtype value.
        _expect(dtype in oneflow.dtypes(), "a oneflow dtype", dtype)
        proto_dtype = oneflow._oneflow_internal.deprecated.GetProtoDtype4OfDtype(
            dtype)
        _expect(isinstance(proto_dtype, int),
                "GetProtoDtype4OfDtype to return an int", proto_dtype)
        return data_type_cfg.DataType(proto_dtype)

    attribute = user_op_attr_cfg.AttrValue()
    if not isinstance(attr_name, str):
        raise AssertionError(
            "attribute name must be a str, got {!r}".format(attr_name))
    attr_type = oneflow._oneflow_internal.GetUserOpAttrType(
        op_type_name, attr_name)

    if attr_type == user_op_attr_cfg.kAtInt32:
        _expect(isinstance(attr_value, int), "an int", attr_value)
        attribute.set_at_int32(attr_value)
    elif attr_type == user_op_attr_cfg.kAtInt64:
        _expect(isinstance(attr_value, int), "an int", attr_value)
        attribute.set_at_int64(attr_value)
    elif attr_type == user_op_attr_cfg.kAtBool:
        _expect(isinstance(attr_value, bool), "a bool", attr_value)
        attribute.set_at_bool(attr_value)
    elif attr_type == user_op_attr_cfg.kAtFloat:
        _expect(isinstance(attr_value, (float, int)), "a float or int",
                attr_value)
        attribute.set_at_float(attr_value)
    elif attr_type == user_op_attr_cfg.kAtDouble:
        _expect(isinstance(attr_value, (float, int)), "a float or int",
                attr_value)
        attribute.set_at_double(attr_value)
    elif attr_type == user_op_attr_cfg.kAtString:
        _expect(isinstance(attr_value, str), "a str", attr_value)
        attribute.set_at_string(attr_value)
    elif attr_type == user_op_attr_cfg.kAtShape:
        _expect_sequence(attr_value)
        _append_checked(attribute.mutable_at_shape().add_dim, attr_value,
                        int, "an int dimension")
    elif attr_type == user_op_attr_cfg.kAtDataType:
        attribute.set_at_data_type(_to_data_type(attr_value))
    elif attr_type == user_op_attr_cfg.kAtListInt32:
        _expect_sequence(attr_value)
        _append_checked(attribute.mutable_at_list_int32().add_val, attr_value,
                        int, "an int")
    elif attr_type == user_op_attr_cfg.kAtListInt64:
        _expect_sequence(attr_value)
        _append_checked(attribute.mutable_at_list_int64().add_val, attr_value,
                        int, "an int")
    elif attr_type == user_op_attr_cfg.kAtListFloat:
        _expect_sequence(attr_value)
        _append_checked(attribute.mutable_at_list_float().add_val, attr_value,
                        (float, int), "a float or int")
    elif attr_type == user_op_attr_cfg.kAtListDataType:
        _expect_sequence(attr_value)
        data_types = attribute.mutable_at_list_data_type()
        for dtype in attr_value:
            data_types.add_val(_to_data_type(dtype))
    elif attr_type == user_op_attr_cfg.kAtListShape:
        _expect_sequence(attr_value)
        shapes = attribute.mutable_at_list_shape().mutable_val()
        for dims in attr_value:
            _expect(isinstance(dims, (tuple, list)),
                    "a tuple or list of ints", dims)
            shape = shape_cfg.ShapeProto()
            _append_checked(shape.add_dim, dims, int, "an int dimension")
            shapes.Add().CopyFrom(shape)
    elif attr_type == user_op_attr_cfg.kAtListString:
        _expect_sequence(attr_value)
        _append_checked(attribute.mutable_at_list_string().add_val,
                        attr_value, str, "a str")
    else:
        raise ValueError("Invalid op attribute type {}".format(attr_type))
    return attribute
Example #2
0
    def Attr(self, attr_name, attr_value, attr_type_name=None):
        """Set the value of one attribute of the user op being built.

        The attribute's declared type is looked up from the op's
        ``op_type_name``. ``attr_value`` is validated against that type,
        converted to a ``user_op_attr_cfg.AttrValue``, and copied (through its
        text format) into ``self.user_op_.op_conf_.user_conf.attr[attr_name]``.

        Args:
            attr_name (str): attribute name of the op.
            attr_value (Any): attribute value. The accepted Python type
                depends on the attribute type: int, bool, float/int, str, a
                dtype from ``flow.dtypes()``, or a tuple/list of those (shapes
                are tuples/lists of int). ``bool`` also passes the ``int``
                checks, since it is an ``int`` subclass.
            attr_type_name: deprecated and otherwise ignored. Passing anything
                other than None prints a deprecation notice and the caller's
                stack frame.

        Raises:
            AssertionError: if ``attr_name`` is not a str, or ``attr_value``
                (or one of its elements) does not match the attribute's type.
                The checks are explicit raises, so they also run under
                ``python -O``.
            ValueError: if the attribute type is not supported here.

        Returns:
            The builder itself (``self``), so calls can be chained.
        """
        if attr_type_name is not None:
            print(
                'WARNING: Argument \'attr_type_name\' of UserOpConfBuilder.Attr has been deprecated. Please remove it.\n\n            For instance:\n                -     .Attr("out_num", out_num, "AttrTypeInt64")\n                +     .Attr("out_num", out_num)\n                        '
            )
            # [-2] is the frame that called Attr; keep this call directly in
            # Attr's body so the index stays correct.
            print(traceback.format_stack()[-2])
        attribute = user_op_attr_cfg.AttrValue()
        if not isinstance(attr_name, str):
            raise AssertionError(
                "attribute name must be a str, got {!r}".format(attr_name))
        op_type_name = self.user_op_.op_conf_.user_conf.op_type_name
        attr_type = oneflow._oneflow_internal.GetUserOpAttrType(
            op_type_name, attr_name)

        def _expect(ok, what, value):
            # Explicit raise rather than ``assert`` so the check survives -O.
            if not ok:
                raise AssertionError(
                    "attribute {!r} of op {!r}: expected {}, got {!r} of type {}"
                    .format(attr_name, op_type_name, what, value,
                            type(value).__name__))

        def _expect_sequence(value):
            _expect(isinstance(value, (tuple, list)), "a tuple or list", value)

        def _append_checked(add, items, types, what):
            # Type-check each element, then hand it to the cfg ``add`` method.
            for item in items:
                _expect(isinstance(item, types), what, item)
                add(item)

        def _to_data_type(dtype):
            # Convert a flow dtype into ``data_type_cfg.DataType`` through its
            # integer proto dtype value.
            _expect(dtype in flow.dtypes(), "a oneflow dtype", dtype)
            proto_dtype = oneflow._oneflow_internal.deprecated.GetProtoDtype4OfDtype(
                dtype)
            _expect(isinstance(proto_dtype, int),
                    "GetProtoDtype4OfDtype to return an int", proto_dtype)
            return data_type_cfg.DataType(proto_dtype)

        if attr_type == user_op_attr_cfg.kAtInt32:
            _expect(isinstance(attr_value, int), "an int", attr_value)
            attribute.set_at_int32(attr_value)
        elif attr_type == user_op_attr_cfg.kAtInt64:
            _expect(isinstance(attr_value, int), "an int", attr_value)
            attribute.set_at_int64(attr_value)
        elif attr_type == user_op_attr_cfg.kAtBool:
            _expect(isinstance(attr_value, bool), "a bool", attr_value)
            attribute.set_at_bool(attr_value)
        elif attr_type == user_op_attr_cfg.kAtFloat:
            _expect(isinstance(attr_value, (float, int)), "a float or int",
                    attr_value)
            attribute.set_at_float(attr_value)
        elif attr_type == user_op_attr_cfg.kAtDouble:
            _expect(isinstance(attr_value, (float, int)), "a float or int",
                    attr_value)
            attribute.set_at_double(attr_value)
        elif attr_type == user_op_attr_cfg.kAtString:
            _expect(isinstance(attr_value, str), "a str", attr_value)
            attribute.set_at_string(attr_value)
        elif attr_type == user_op_attr_cfg.kAtShape:
            _expect_sequence(attr_value)
            _append_checked(attribute.mutable_at_shape().add_dim, attr_value,
                            int, "an int dimension")
        elif attr_type == user_op_attr_cfg.kAtDataType:
            attribute.set_at_data_type(_to_data_type(attr_value))
        elif attr_type == user_op_attr_cfg.kAtListInt32:
            _expect_sequence(attr_value)
            _append_checked(attribute.mutable_at_list_int32().add_val,
                            attr_value, int, "an int")
        elif attr_type == user_op_attr_cfg.kAtListInt64:
            _expect_sequence(attr_value)
            _append_checked(attribute.mutable_at_list_int64().add_val,
                            attr_value, int, "an int")
        elif attr_type == user_op_attr_cfg.kAtListFloat:
            _expect_sequence(attr_value)
            _append_checked(attribute.mutable_at_list_float().add_val,
                            attr_value, (float, int), "a float or int")
        elif attr_type == user_op_attr_cfg.kAtListDataType:
            _expect_sequence(attr_value)
            data_types = attribute.mutable_at_list_data_type()
            for dtype in attr_value:
                data_types.add_val(_to_data_type(dtype))
        elif attr_type == user_op_attr_cfg.kAtListShape:
            _expect_sequence(attr_value)
            shapes = attribute.mutable_at_list_shape().mutable_val()
            for dims in attr_value:
                _expect(isinstance(dims, (tuple, list)),
                        "a tuple or list of ints", dims)
                shape = shape_cfg.ShapeProto()
                _append_checked(shape.add_dim, dims, int, "an int dimension")
                shapes.Add().CopyFrom(shape)
        elif attr_type == user_op_attr_cfg.kAtListString:
            _expect_sequence(attr_value)
            _append_checked(attribute.mutable_at_list_string().add_val,
                            attr_value, str, "a str")
        else:
            raise ValueError("Invalid op attribute type {}".format(attr_type))
        # The cfg object is converted to the protobuf AttrValue by
        # round-tripping through its text representation.
        self.user_op_.op_conf_.user_conf.attr[attr_name].CopyFrom(
            text_format.Parse(str(attribute), attr_value_pb.AttrValue()))
        return self