Beispiel #1
0
    def Attr(self, attr_name, attr_value, attr_type_name=None):
        r"""Set value of op's attribute.

        The Python type expected for ``attr_value`` is decided by the attribute
        type registered for this op's ``op_type_name``, as reported by
        ``oneflow_api.GetUserOpAttrType``. The converted value is copied into
        ``user_conf.attr[attr_name]``.

        Args:
            attr_name (str): attribute name of op
            attr_value (Any): attribute value of op. Float/double attributes
                (scalar or list) accept ``int`` as well as ``float``.
            attr_type_name (str, optional): deprecated and otherwise ignored;
                when given, a deprecation warning and the caller's stack frame
                are printed.

        Raises:
            AssertionError: (only when assertions are enabled) if ``attr_name``
                is not a str or ``attr_value`` does not match the op's
                attribute type.
            ValueError: if the op's attribute type is not one handled here.

        Returns:
            The builder itself (``self``), so calls can be chained.
        """
        if attr_type_name is not None:
            print(
                """WARNING: Argument 'attr_type_name' of UserOpConfBuilder.Attr has been deprecated. Please remove it.
            
            For instance:
                -     .Attr("out_num", out_num, "AttrTypeInt64")
                +     .Attr("out_num", out_num)
                        """)
            # format_stack()[-1] is this frame; [-2] is the caller that
            # still passes the deprecated argument.
            print(traceback.format_stack()[-2])

        attribute = attr_value_pb.AttrValue()
        assert isinstance(attr_name, str)
        attr_type = oneflow_api.GetUserOpAttrType(
            self.user_op_.op_conf_.user_conf.op_type_name, attr_name)
        if attr_type == attr_value_pb.kAtInt32:
            assert isinstance(attr_value, int)
            attribute.at_int32 = attr_value
        elif attr_type == attr_value_pb.kAtInt64:
            assert isinstance(attr_value, int)
            attribute.at_int64 = attr_value
        elif attr_type == attr_value_pb.kAtBool:
            assert isinstance(attr_value, bool)
            attribute.at_bool = attr_value
        elif attr_type == attr_value_pb.kAtFloat:
            # Integral values are valid floats (e.g. ``.Attr("alpha", 1)``).
            assert isinstance(attr_value, (float, int))
            attribute.at_float = attr_value
        elif attr_type == attr_value_pb.kAtDouble:
            assert isinstance(attr_value, (float, int))
            attribute.at_double = attr_value
        elif attr_type == attr_value_pb.kAtString:
            assert isinstance(attr_value, str)
            attribute.at_string = attr_value
        elif attr_type == attr_value_pb.kAtShape:
            assert isinstance(attr_value, (tuple, list))
            assert all(isinstance(x, int) for x in attr_value)
            attribute.at_shape.dim[:] = list(attr_value)
        elif attr_type == attr_value_pb.kAtDataType:
            # Check membership first so a non-dtype value fails the assertion
            # rather than raising AttributeError on ``oneflow_proto_dtype``.
            assert attr_value in oneflow.dtypes()
            assert isinstance(attr_value.oneflow_proto_dtype, int)
            attribute.at_data_type = attr_value.oneflow_proto_dtype
        elif attr_type == attr_value_pb.kAtListInt32:
            assert isinstance(attr_value, (tuple, list))
            assert all(isinstance(x, int) for x in attr_value)
            attribute.at_list_int32.val[:] = list(attr_value)
        elif attr_type == attr_value_pb.kAtListInt64:
            assert isinstance(attr_value, (tuple, list))
            assert all(isinstance(x, int) for x in attr_value)
            attribute.at_list_int64.val[:] = list(attr_value)
        elif attr_type == attr_value_pb.kAtListFloat:
            assert isinstance(attr_value, (tuple, list))
            assert all(isinstance(x, (float, int)) for x in attr_value)
            attribute.at_list_float.val[:] = list(attr_value)
        elif attr_type == attr_value_pb.kAtListDataType:
            assert isinstance(attr_value, (tuple, list))
            assert all(x in oneflow.dtypes() for x in attr_value)
            assert all(
                isinstance(x.oneflow_proto_dtype, int) for x in attr_value)
            attribute.at_list_data_type.val[:] = [
                x.oneflow_proto_dtype for x in attr_value
            ]
        elif attr_type == attr_value_pb.kAtListShape:
            assert isinstance(attr_value, (tuple, list))
            for dims in attr_value:
                assert isinstance(dims, (tuple, list))
                assert all(isinstance(d, int) for d in dims)
                shape = shape_util.ShapeProto()
                shape.dim[:] = list(dims)
                attribute.at_list_shape.val.append(shape)
        elif attr_type == attr_value_pb.kAtListString:
            assert isinstance(attr_value, (tuple, list))
            assert all(isinstance(x, str) for x in attr_value)
            attribute.at_list_string.val[:] = list(attr_value)
        else:
            raise ValueError("Invalid op attribute type {}".format(attr_type))

        self.user_op_.op_conf_.user_conf.attr[attr_name].CopyFrom(attribute)
        return self
Beispiel #2
0
    def Attr(self, attr_name, attr_value, attr_type_name=None):
        """Set value of op's attribute.

        The value is first written into a cfg ``AttrValue``. The setter used is
        chosen by the attribute type registered for this op's
        ``op_type_name``, as reported by
        ``oneflow._oneflow_internal.GetUserOpAttrType``. The cfg object is then
        converted to the protobuf ``AttrValue`` by parsing ``str(attribute)``
        with ``text_format.Parse``. The result is copied into
        ``user_conf.attr[attr_name]``.

        Args:
            attr_name (str): attribute name of op
            attr_value (Any): attribute value of op. Float/double attributes
                (scalar or list) accept ``int`` as well as ``float``.
            attr_type_name (str, optional): deprecated and otherwise ignored;
                when given, a deprecation warning and the caller's stack frame
                are printed.

        Raises:
            AssertionError: (only when assertions are enabled) if ``attr_name``
                is not a str or ``attr_value`` does not match the op's
                attribute type.
            ValueError: if the op's attribute type is not one handled here.

        Returns:
            The builder itself (``self``), so calls can be chained.
        """
        if attr_type_name is not None:
            print(
                'WARNING: Argument \'attr_type_name\' of UserOpConfBuilder.Attr has been deprecated. Please remove it.\n\n            For instance:\n                -     .Attr("out_num", out_num, "AttrTypeInt64")\n                +     .Attr("out_num", out_num)\n                        '
            )
            # format_stack()[-1] is this frame; [-2] is the caller that
            # still passes the deprecated argument.
            print(traceback.format_stack()[-2])
        attribute = user_op_attr_cfg.AttrValue()
        assert isinstance(attr_name, str)
        attr_type = oneflow._oneflow_internal.GetUserOpAttrType(
            self.user_op_.op_conf_.user_conf.op_type_name, attr_name)
        if attr_type == user_op_attr_cfg.kAtInt32:
            assert isinstance(attr_value, int)
            attribute.set_at_int32(attr_value)
        elif attr_type == user_op_attr_cfg.kAtInt64:
            assert isinstance(attr_value, int)
            attribute.set_at_int64(attr_value)
        elif attr_type == user_op_attr_cfg.kAtBool:
            assert isinstance(attr_value, bool)
            attribute.set_at_bool(attr_value)
        elif attr_type == user_op_attr_cfg.kAtFloat:
            assert isinstance(attr_value, (float, int))
            attribute.set_at_float(attr_value)
        elif attr_type == user_op_attr_cfg.kAtDouble:
            assert isinstance(attr_value, (float, int))
            attribute.set_at_double(attr_value)
        elif attr_type == user_op_attr_cfg.kAtString:
            assert isinstance(attr_value, str)
            attribute.set_at_string(attr_value)
        elif attr_type == user_op_attr_cfg.kAtShape:
            assert isinstance(attr_value, (tuple, list))
            shape_field = attribute.mutable_at_shape()
            for dim in attr_value:
                assert isinstance(dim, int)
                shape_field.add_dim(dim)
        elif attr_type == user_op_attr_cfg.kAtDataType:
            assert attr_value in flow.dtypes()
            # Translate the Python dtype object into its proto enum value.
            proto_dtype = (
                oneflow._oneflow_internal.deprecated.GetProtoDtype4OfDtype(
                    attr_value))
            assert isinstance(proto_dtype, int)
            attribute.set_at_data_type(data_type_cfg.DataType(proto_dtype))
        elif attr_type == user_op_attr_cfg.kAtListInt32:
            assert isinstance(attr_value, (tuple, list))
            int32_list = attribute.mutable_at_list_int32()
            for x in attr_value:
                assert isinstance(x, int)
                int32_list.add_val(x)
        elif attr_type == user_op_attr_cfg.kAtListInt64:
            assert isinstance(attr_value, (tuple, list))
            int64_list = attribute.mutable_at_list_int64()
            for x in attr_value:
                assert isinstance(x, int)
                int64_list.add_val(x)
        elif attr_type == user_op_attr_cfg.kAtListFloat:
            assert isinstance(attr_value, (tuple, list))
            float_list = attribute.mutable_at_list_float()
            for x in attr_value:
                assert isinstance(x, (float, int))
                float_list.add_val(x)
        elif attr_type == user_op_attr_cfg.kAtListDataType:
            assert isinstance(attr_value, (tuple, list))
            dtype_list = attribute.mutable_at_list_data_type()
            for dtype in attr_value:
                assert dtype in flow.dtypes()
                proto_dtype = (
                    oneflow._oneflow_internal.deprecated.GetProtoDtype4OfDtype(
                        dtype))
                assert isinstance(proto_dtype, int)
                dtype_list.add_val(data_type_cfg.DataType(proto_dtype))
        elif attr_type == user_op_attr_cfg.kAtListShape:
            assert isinstance(attr_value, (tuple, list))
            shape_list = attribute.mutable_at_list_shape().mutable_val()
            for dims in attr_value:
                assert isinstance(dims, (tuple, list))
                shape = shape_cfg.ShapeProto()
                for dim in dims:
                    assert isinstance(dim, int)
                    shape.add_dim(dim)
                shape_list.Add().CopyFrom(shape)
        elif attr_type == user_op_attr_cfg.kAtListString:
            assert isinstance(attr_value, (tuple, list))
            string_list = attribute.mutable_at_list_string()
            for x in attr_value:
                assert isinstance(x, str)
                string_list.add_val(x)
        else:
            raise ValueError("Invalid op attribute type {}".format(attr_type))
        # Convert cfg -> protobuf by round-tripping through text format; this
        # relies on str(attribute) producing protobuf text format.
        self.user_op_.op_conf_.user_conf.attr[attr_name].CopyFrom(
            text_format.Parse(str(attribute), attr_value_pb.AttrValue()))
        return self