def _GetInterfaceBlobObject(builder, op_name):
    """Return the blob object associated with the interface op ``op_name``.

    When eager execution is enabled, the blob object is read from the default
    session's ``var_name2var_blob`` mapping. Otherwise a lazy reference blob
    object is created through ``builder`` from the op attribute and parallel
    conf that the default session reports for ``op_name``.

    Args:
        builder: object exposing ``MakeLazyRefBlobObject(op_name,
            op_attribute, parallel_conf)``.
        op_name: name of the interface op to look up.

    Returns:
        The blob object stored in the session (eager mode) or the value
        returned by ``builder.MakeLazyRefBlobObject`` (lazy mode).
    """
    sess = session_ctx.GetDefaultSession()
    if oneflow._oneflow_internal.EagerExecutionEnabled():
        return sess.var_name2var_blob[op_name].blob_object

    # The cfg op attribute is rebuilt from the string form of the session's
    # op attribute.
    op_attribute = sess.OpAttribute4InterfaceOpName(op_name)
    cfg_op_attribute = oneflow._oneflow_internal.deprecated.MakeOpAttributeByString(
        str(op_attribute))

    parallel_conf = sess.ParallelConf4LazyInterfaceOpName(op_name)
    if not isinstance(
            parallel_conf,
            oneflow._oneflow_internal.oneflow.core.job.placement.ParallelConf):
        # The session handed back something other than the cfg ParallelConf:
        # copy device tag, device names and (if set) hierarchy into a cfg one.
        # NOTE(review): presumably placement_cfg.ParallelConf is the same class
        # as the one named in the isinstance check above -- confirm.
        parallel_conf_cfg = placement_cfg.ParallelConf()
        parallel_conf_cfg.set_device_tag(parallel_conf.device_tag)
        for device_name in parallel_conf.device_name:
            parallel_conf_cfg.add_device_name(device_name)
        if parallel_conf.HasField("hierarchy"):
            hierarchy = shape_proto_cfg.ShapeProto()
            for dim in parallel_conf.hierarchy.dim:
                hierarchy.add_dim(dim)
            # A hierarchy that is present is expected to be non-empty.
            assert hierarchy.dim_size() > 0
            parallel_conf_cfg.mutable_hierarchy().CopyFrom(hierarchy)
        parallel_conf = parallel_conf_cfg

    return builder.MakeLazyRefBlobObject(op_name, cfg_op_attribute,
                                         parallel_conf)
Beispiel #2
0
def _inferface_blob_conf_proto_to_cfg(inferface_blob_conf_proto,
                                      mut_inferface_blob_conf_cfg):
    """Copy fields of an ``InterfaceBlobConf`` proto into its cfg counterpart.

    The shape, data type and ``is_dynamic`` flag are always copied. If the
    proto has ``parallel_distribution`` set, it must hold exactly one sbp
    entry; that entry is copied only when it is a split-parallel one.

    Args:
        inferface_blob_conf_proto: source
            ``interface_blob_conf_proto.InterfaceBlobConf`` message.
        mut_inferface_blob_conf_cfg: destination
            ``interface_blob_conf_proto_cfg.InterfaceBlobConf``, mutated in
            place.

    Raises:
        AssertionError: if either argument has an unexpected type, or if
            ``parallel_distribution`` does not contain exactly one sbp entry.
    """
    src = inferface_blob_conf_proto
    dst = mut_inferface_blob_conf_cfg
    assert isinstance(src, interface_blob_conf_proto.InterfaceBlobConf)
    assert isinstance(dst, interface_blob_conf_proto_cfg.InterfaceBlobConf)

    # Shape: rebuild dimension by dimension, then copy into the destination.
    shape_msg = shape_proto_cfg.ShapeProto()
    for extent in src.shape.dim:
        shape_msg.add_dim(extent)
    dst.mutable_shape().CopyFrom(shape_msg)

    # Data type: the proto enum value is passed through as a plain int.
    dst.set_data_type(dtype_proto_cfg.DataType(int(src.data_type)))

    if src.HasField("parallel_distribution"):
        # TODO(guoran): Process Nd sbp, parallel_distribution_cfg CopyFrom parallel_distribution_proto
        sbp_entries = src.parallel_distribution.sbp_parallel
        assert len(sbp_entries) == 1
        only_sbp = sbp_entries[0]
        if only_sbp.HasField("split_parallel"):
            axis = only_sbp.split_parallel.axis
            sbp_msg = sbp_parallel_cfg.SbpParallel()
            sbp_msg.mutable_split_parallel().set_axis(axis)
            dst_sbp_list = dst.mutable_parallel_distribution(
            ).mutable_sbp_parallel()
            dst_sbp_list.Add().CopyFrom(sbp_msg)

    dst.set_is_dynamic(src.is_dynamic)
Beispiel #3
0
def ReplaceDeviceTag(parallel_desc_symbol, device_tag, builder=None):
    """Build a placement like ``parallel_desc_symbol`` but with ``device_tag``.

    A new ``placement_cfg.ParallelConf`` is filled with the requested device
    tag plus the device names and hierarchy of ``parallel_desc_symbol``.

    Args:
        parallel_desc_symbol: source placement; its ``device_tag`` must differ
            from ``device_tag``.
        device_tag: device tag to put on the new parallel conf.
        builder: optional object exposing ``GetParallelDescSymbol``. When
            omitted, a ``PlacementSymbol`` is constructed directly with the
            source symbol's ``symbol_id``.

    Returns:
        ``builder.GetParallelDescSymbol(conf)`` when a builder is supplied,
        otherwise ``oneflow._oneflow_internal.PlacementSymbol(symbol_id, conf)``.

    Raises:
        AssertionError: if the device tag is unchanged or the source hierarchy
            has no dimensions.
    """
    assert parallel_desc_symbol.device_tag != device_tag

    new_conf = placement_cfg.ParallelConf()
    new_conf.set_device_tag(device_tag)
    for name in parallel_desc_symbol.parallel_conf.device_name():
        new_conf.add_device_name(name)

    # Carry the hierarchy over dimension by dimension.
    hierarchy_msg = shape_proto_cfg.ShapeProto()
    for extent in parallel_desc_symbol.hierarchy:
        hierarchy_msg.add_dim(extent)
    assert hierarchy_msg.dim_size() > 0
    new_conf.mutable_hierarchy().CopyFrom(hierarchy_msg)

    if builder is not None:
        return builder.GetParallelDescSymbol(new_conf)
    return oneflow._oneflow_internal.PlacementSymbol(
        parallel_desc_symbol.symbol_id, new_conf)
Beispiel #4
0
def convert_to_user_attr_value(op_type_name, attr_name, attr_value):
    """Convert a Python value into a ``user_op_attr_cfg.AttrValue``.

    The attribute type is looked up with ``GetUserOpAttrType(op_type_name,
    attr_name)`` and decides which field of the returned ``AttrValue`` is
    populated.

    Args:
        op_type_name: user op type name used for the attribute type lookup.
        attr_name (str): attribute name.
        attr_value: Python value; its required type depends on the attribute
            type (int, bool, float, str, a oneflow dtype, or a tuple/list of
            those, or a tuple/list of int tuples/lists for shape lists).

    Returns:
        The populated ``user_op_attr_cfg.AttrValue``.

    Raises:
        AssertionError: if ``attr_name`` is not a str or ``attr_value`` does
            not have the Python type required by the attribute type.
        ValueError: if the attribute type is not one of the handled kinds.
    """
    result = user_op_attr_cfg.AttrValue()
    assert isinstance(attr_name, str)
    kind = oneflow._oneflow_internal.GetUserOpAttrType(op_type_name,
                                                       attr_name)

    # ---- scalar kinds -----------------------------------------------------
    if kind == user_op_attr_cfg.kAtInt32:
        assert isinstance(attr_value, int)
        result.set_at_int32(attr_value)
        return result

    if kind == user_op_attr_cfg.kAtInt64:
        assert isinstance(attr_value, int)
        result.set_at_int64(attr_value)
        return result

    if kind == user_op_attr_cfg.kAtBool:
        assert isinstance(attr_value, bool)
        result.set_at_bool(attr_value)
        return result

    if kind == user_op_attr_cfg.kAtFloat:
        assert isinstance(attr_value, (float, int))
        result.set_at_float(attr_value)
        return result

    if kind == user_op_attr_cfg.kAtDouble:
        assert isinstance(attr_value, (float, int))
        result.set_at_double(attr_value)
        return result

    if kind == user_op_attr_cfg.kAtString:
        assert isinstance(attr_value, str)
        result.set_at_string(attr_value)
        return result

    if kind == user_op_attr_cfg.kAtShape:
        assert isinstance(attr_value, (tuple, list))
        shape_field = result.mutable_at_shape()
        for extent in attr_value:
            assert isinstance(extent, int)
            shape_field.add_dim(extent)
        return result

    if kind == user_op_attr_cfg.kAtDataType:
        assert attr_value in oneflow.dtypes()
        # The oneflow dtype is translated to its integer proto value first.
        proto_dtype = oneflow._oneflow_internal.deprecated.GetProtoDtype4OfDtype(
            attr_value)
        assert isinstance(proto_dtype, int)
        result.set_at_data_type(data_type_cfg.DataType(proto_dtype))
        return result

    # ---- list kinds -------------------------------------------------------
    if kind == user_op_attr_cfg.kAtListInt32:
        assert isinstance(attr_value, (tuple, list))
        int32_list = result.mutable_at_list_int32()
        for item in attr_value:
            assert isinstance(item, int)
            int32_list.add_val(item)
        return result

    if kind == user_op_attr_cfg.kAtListInt64:
        assert isinstance(attr_value, (tuple, list))
        int64_list = result.mutable_at_list_int64()
        for item in attr_value:
            assert isinstance(item, int)
            int64_list.add_val(item)
        return result

    if kind == user_op_attr_cfg.kAtListFloat:
        assert isinstance(attr_value, (tuple, list))
        float_list = result.mutable_at_list_float()
        for item in attr_value:
            assert isinstance(item, (float, int))
            float_list.add_val(item)
        return result

    if kind == user_op_attr_cfg.kAtListDataType:
        assert isinstance(attr_value, (tuple, list))
        dtype_list = result.mutable_at_list_data_type()
        for item in attr_value:
            assert item in oneflow.dtypes()
            proto_dtype = oneflow._oneflow_internal.deprecated.GetProtoDtype4OfDtype(
                item)
            assert isinstance(proto_dtype, int)
            dtype_list.add_val(data_type_cfg.DataType(proto_dtype))
        return result

    if kind == user_op_attr_cfg.kAtListShape:
        assert isinstance(attr_value, (tuple, list))
        shape_list = result.mutable_at_list_shape().mutable_val()
        for dims in attr_value:
            assert isinstance(dims, (tuple, list))
            one_shape = shape_cfg.ShapeProto()
            for extent in dims:
                assert isinstance(extent, int)
                one_shape.add_dim(extent)
            shape_list.Add().CopyFrom(one_shape)
        return result

    if kind == user_op_attr_cfg.kAtListString:
        assert isinstance(attr_value, (tuple, list))
        string_list = result.mutable_at_list_string()
        for item in attr_value:
            assert isinstance(item, str)
            string_list.add_val(item)
        return result

    raise ValueError("Invalid op attribute type {}".format(kind))
Beispiel #5
0
    def Attr(self, attr_name, attr_value, attr_type_name=None):
        """Set the value of one of the op's attributes.

        The attribute type is looked up from the op's type name and
        ``attr_name``; ``attr_value`` is checked against that type, converted
        to an attribute message and stored in the op conf under ``attr_name``.

        Args:
            attr_name (str): attribute name of op
            attr_value (Any): attribute value of op
            attr_type_name: deprecated and ignored. Passing anything other
                than ``None`` only prints a deprecation warning together with
                a stack frame from ``traceback.format_stack()``.

        Raises:
            AssertionError: raised when ``attr_name`` is not a str or
                ``attr_value`` is not of the Python type required by the op's
                attribute type.
            ValueError: raised when the op's attribute type is not one of the
                handled kinds.

        Returns:
            This builder (``self``), so that calls can be chained.
        """
        if attr_type_name is not None:
            print(
                'WARNING: Argument \'attr_type_name\' of UserOpConfBuilder.Attr has been deprecated. Please remove it.\n\n            For instance:\n                -     .Attr("out_num", out_num, "AttrTypeInt64")\n                +     .Attr("out_num", out_num)\n                        '
            )
            print(traceback.format_stack()[-2])
        attribute = user_op_attr_cfg.AttrValue()
        assert isinstance(attr_name, str)
        attr_type = oneflow._oneflow_internal.GetUserOpAttrType(
            self.user_op_.op_conf_.user_conf.op_type_name, attr_name)
        if attr_type == user_op_attr_cfg.kAtInt32:
            assert isinstance(attr_value, int)
            attribute.set_at_int32(attr_value)
        elif attr_type == user_op_attr_cfg.kAtInt64:
            assert isinstance(attr_value, int)
            attribute.set_at_int64(attr_value)
        elif attr_type == user_op_attr_cfg.kAtBool:
            assert isinstance(attr_value, bool)
            attribute.set_at_bool(attr_value)
        elif attr_type == user_op_attr_cfg.kAtFloat:
            assert isinstance(attr_value, (float, int))
            attribute.set_at_float(attr_value)
        elif attr_type == user_op_attr_cfg.kAtDouble:
            assert isinstance(attr_value, (float, int))
            attribute.set_at_double(attr_value)
        elif attr_type == user_op_attr_cfg.kAtString:
            assert isinstance(attr_value, str)
            attribute.set_at_string(attr_value)
        elif attr_type == user_op_attr_cfg.kAtShape:
            assert isinstance(attr_value, (tuple, list))
            attribute_mutable_at_shape = attribute.mutable_at_shape()
            for x in attr_value:
                assert isinstance(x, int)
                attribute_mutable_at_shape.add_dim(x)
        elif attr_type == user_op_attr_cfg.kAtDataType:
            assert attr_value in flow.dtypes()
            # Translate the oneflow dtype to its integer proto value.
            attr_value = oneflow._oneflow_internal.deprecated.GetProtoDtype4OfDtype(
                attr_value)
            assert isinstance(attr_value, int)
            attribute.set_at_data_type(data_type_cfg.DataType(attr_value))
        elif attr_type == user_op_attr_cfg.kAtListInt32:
            assert isinstance(attr_value, (tuple, list))
            attribute_mutable_at_list_int32 = attribute.mutable_at_list_int32()
            for x in attr_value:
                assert isinstance(x, int)
                attribute_mutable_at_list_int32.add_val(x)
        elif attr_type == user_op_attr_cfg.kAtListInt64:
            assert isinstance(attr_value, (tuple, list))
            attribute_mutable_at_list_int64 = attribute.mutable_at_list_int64()
            for x in attr_value:
                assert isinstance(x, int)
                attribute_mutable_at_list_int64.add_val(x)
        elif attr_type == user_op_attr_cfg.kAtListFloat:
            assert isinstance(attr_value, (tuple, list))
            attribute_mutable_at_list_float = attribute.mutable_at_list_float()
            for x in attr_value:
                assert isinstance(x, (float, int))
                attribute_mutable_at_list_float.add_val(x)
        elif attr_type == user_op_attr_cfg.kAtListDataType:
            assert isinstance(attr_value, (tuple, list))
            attribute_mutable_at_list_data_type = attribute.mutable_at_list_data_type(
            )
            for x in attr_value:
                assert x in flow.dtypes()
                x = oneflow._oneflow_internal.deprecated.GetProtoDtype4OfDtype(
                    x)
                assert isinstance(x, int)
                attribute_mutable_at_list_data_type.add_val(
                    data_type_cfg.DataType(x))
        elif attr_type == user_op_attr_cfg.kAtListShape:
            assert isinstance(attr_value, (tuple, list))
            attribute_mutable_at_list_shape = (
                attribute.mutable_at_list_shape().mutable_val())
            for x in attr_value:
                assert isinstance(x, (tuple, list))
                shape = shape_cfg.ShapeProto()
                for dim in x:
                    assert isinstance(dim, int)
                    shape.add_dim(dim)
                attribute_mutable_at_list_shape.Add().CopyFrom(shape)
        elif attr_type == user_op_attr_cfg.kAtListString:
            assert isinstance(attr_value, (tuple, list))
            attribute_mutable_at_list_string = attribute.mutable_at_list_string(
            )
            for x in attr_value:
                assert isinstance(x, str)
                attribute_mutable_at_list_string.add_val(x)
        else:
            raise ValueError("Invalid op attribute type {}".format(attr_type))
        # The cfg attribute is turned into a protobuf AttrValue by parsing its
        # string form, then copied into the op conf.
        self.user_op_.op_conf_.user_conf.attr[attr_name].CopyFrom(
            text_format.Parse(str(attribute), attr_value_pb.AttrValue()))
        return self