예제 #1
0
def _GetInterfaceBlobObject(builder, op_name):
    """Return the blob object backing the interface op named ``op_name``.

    When eager execution is enabled, the blob object is read from the default
    session's ``var_name2var_blob`` table. Otherwise a lazy reference blob
    object is created with ``builder.MakeLazyRefBlobObject`` from the op
    attribute and parallel conf that the session reports for ``op_name``.

    Args:
        builder: object providing ``MakeLazyRefBlobObject``; only used when
            eager execution is disabled.
        op_name: name of the interface op to look up.

    Returns:
        The blob object associated with ``op_name``.
    """
    # Fetch the default session once; both execution modes read from it.
    sess = session_ctx.GetDefaultSession()
    if oneflow_api.EagerExecutionEnabled():
        return sess.var_name2var_blob[op_name].blob_object

    op_attribute = sess.OpAttribute4InterfaceOpName(op_name)
    cfg_op_attribute = oneflow_api.deprecated.MakeOpAttributeByString(str(op_attribute))
    parallel_conf = sess.ParallelConf4LazyInterfaceOpName(op_name)
    if not isinstance(
        parallel_conf, oneflow_api.oneflow.core.job.placement.ParallelConf
    ):
        # The session returned something other than the cfg ParallelConf
        # (presumably the protobuf message -- TODO confirm); rebuild it as a
        # cfg object field by field.
        parallel_conf_cfg = placement_cfg.ParallelConf()
        parallel_conf_cfg.set_device_tag(parallel_conf.device_tag)
        for device_name in parallel_conf.device_name:
            parallel_conf_cfg.add_device_name(device_name)
        if parallel_conf.HasField("hierarchy"):
            hierarchy = shape_proto_cfg.ShapeProto()
            for dim in parallel_conf.hierarchy.dim:
                hierarchy.add_dim(dim)
            # A hierarchy that is present must carry at least one dimension.
            assert hierarchy.dim_size() > 0
            parallel_conf_cfg.mutable_hierarchy().CopyFrom(hierarchy)
        parallel_conf = parallel_conf_cfg

    return builder.MakeLazyRefBlobObject(op_name, cfg_op_attribute, parallel_conf)
예제 #2
0
def _inferface_blob_conf_proto_to_cfg(inferface_blob_conf_proto,
                                      mut_inferface_blob_conf_cfg):
    """Copy an ``InterfaceBlobConf`` protobuf message into its cfg counterpart.

    The shape, data type, parallel distribution and ``is_dynamic`` flag are
    transferred. For the parallel distribution, exactly one sbp entry is
    expected and only a ``split_parallel`` entry is copied; any other kind
    leaves the destination's parallel distribution untouched.

    Args:
        inferface_blob_conf_proto: source
            ``interface_blob_conf_proto.InterfaceBlobConf`` message.
        mut_inferface_blob_conf_cfg: destination
            ``interface_blob_conf_proto_cfg.InterfaceBlobConf``; mutated in place.

    Raises:
        AssertionError: if either argument has the wrong type, or the source
            has a parallel distribution whose sbp list length is not 1.
    """
    assert isinstance(inferface_blob_conf_proto,
                      interface_blob_conf_proto.InterfaceBlobConf)
    assert isinstance(mut_inferface_blob_conf_cfg,
                      interface_blob_conf_proto_cfg.InterfaceBlobConf)

    # Short aliases for the source message and the destination cfg object.
    src = inferface_blob_conf_proto
    dst = mut_inferface_blob_conf_cfg

    shape_msg = shape_proto_cfg.ShapeProto()
    for extent in src.shape.dim:
        shape_msg.add_dim(extent)
    dst.mutable_shape().CopyFrom(shape_msg)

    dst.set_data_type(dtype_proto_cfg.DataType(int(src.data_type)))

    if src.HasField("parallel_distribution"):
        # TODO(guoran): Process Nd sbp, parallel_distribution_cfg CopyFrom parallel_distribution_proto
        sbp_entries = src.parallel_distribution.sbp_parallel
        assert len(sbp_entries) == 1
        only_sbp = sbp_entries[0]
        if only_sbp.HasField("split_parallel"):
            axis = only_sbp.split_parallel.axis
            sbp_msg = sbp_parallel_cfg.SbpParallel()
            sbp_msg.mutable_split_parallel().set_axis(axis)
            dst_sbp_list = dst.mutable_parallel_distribution().mutable_sbp_parallel()
            dst_sbp_list.Add().CopyFrom(sbp_msg)

    dst.set_is_dynamic(src.is_dynamic)
예제 #3
0
def _inferface_blob_conf_proto_to_cfg(
    inferface_blob_conf_proto, mut_inferface_blob_conf_cfg
):
    """Copy an ``InterfaceBlobConf`` protobuf message into its cfg counterpart.

    The shape, data type, split axis, batch axis, ``is_dynamic`` and
    ``is_tensor_list`` fields are transferred. The split and batch axes are
    always written to the destination; their value is only set when the
    source has one.

    Args:
        inferface_blob_conf_proto: source
            ``interface_blob_conf_proto.InterfaceBlobConf`` message.
        mut_inferface_blob_conf_cfg: destination
            ``interface_blob_conf_proto_cfg.InterfaceBlobConf``; mutated in place.

    Raises:
        AssertionError: if either argument has the wrong type.
    """
    assert isinstance(
        inferface_blob_conf_proto, interface_blob_conf_proto.InterfaceBlobConf
    )
    assert isinstance(
        mut_inferface_blob_conf_cfg, interface_blob_conf_proto_cfg.InterfaceBlobConf
    )

    def _to_opt_int64(proto_opt):
        # Build a cfg OptInt64, carrying over the value only when it is set.
        opt = dtype_proto_cfg.OptInt64()
        if proto_opt.HasField("value"):
            opt.set_value(proto_opt.value)
        return opt

    # Short aliases for the source message and the destination cfg object.
    src = inferface_blob_conf_proto
    dst = mut_inferface_blob_conf_cfg

    shape_msg = shape_proto_cfg.ShapeProto()
    for extent in src.shape.dim:
        shape_msg.add_dim(extent)
    dst.mutable_shape().CopyFrom(shape_msg)

    dst.set_data_type(dtype_proto_cfg.DataType(int(src.data_type)))

    # Convert each optional axis before touching the destination field.
    split_axis_opt = _to_opt_int64(src.split_axis)
    dst.mutable_split_axis().CopyFrom(split_axis_opt)

    batch_axis_opt = _to_opt_int64(src.batch_axis)
    dst.mutable_batch_axis().CopyFrom(batch_axis_opt)

    dst.set_is_dynamic(src.is_dynamic)
    dst.set_is_tensor_list(src.is_tensor_list)
예제 #4
0
def GetConcatSplitBoxingParallelDescSymbol(builder, blob_parallel_desc_symbol,
                                           max_parallel_num):
    """Build a "cpu" parallel-desc symbol using one random rank per machine.

    A single rank id is drawn uniformly from ``[0, max_parallel_num - 1]`` and
    paired with every machine id found in
    ``blob_parallel_desc_symbol.machine_id2device_id_list``. The hierarchy of
    the input symbol is copied into the new parallel conf.

    Args:
        builder: object providing ``GetParallelDescSymbol``.
        blob_parallel_desc_symbol: symbol supplying the machine ids and the
            hierarchy.
        max_parallel_num: exclusive upper bound for the random rank id.

    Returns:
        Whatever ``builder.GetParallelDescSymbol`` returns for the new conf.

    Raises:
        AssertionError: if the input symbol's hierarchy has no dimensions.
    """
    # The same randomly chosen rank is used for every machine.
    chosen_rank = random.randint(0, max_parallel_num - 1)

    cpu_conf = placement_cfg.ParallelConf()
    cpu_conf.set_device_tag("cpu")
    machine_table = blob_parallel_desc_symbol.machine_id2device_id_list
    for machine_id, _device_ids in machine_table.items():
        cpu_conf.add_device_name("%s:%s" % (machine_id, chosen_rank))

    hierarchy_shape = shape_proto_cfg.ShapeProto()
    for extent in blob_parallel_desc_symbol.hierarchy:
        hierarchy_shape.add_dim(extent)
    assert hierarchy_shape.dim_size() > 0
    cpu_conf.mutable_hierarchy().CopyFrom(hierarchy_shape)

    return builder.GetParallelDescSymbol(cpu_conf)
예제 #5
0
def ReplaceDeviceTag(parallel_desc_symbol, device_tag, builder=None):
    """Return a placement equal to ``parallel_desc_symbol`` but with a new device tag.

    The device names and hierarchy of ``parallel_desc_symbol`` are copied into
    a fresh parallel conf whose device tag is ``device_tag``.

    Args:
        parallel_desc_symbol: symbol whose device names and hierarchy are reused.
        device_tag: the replacement device tag; must differ from the symbol's
            current one.
        builder: optional object providing ``GetParallelDescSymbol``.

    Returns:
        ``builder.GetParallelDescSymbol(conf)`` when a builder is given,
        otherwise an ``oneflow_api.PlacementSymbol`` constructed from the
        original symbol id and the new conf.

    Raises:
        AssertionError: if ``device_tag`` equals the symbol's current tag, or
            the symbol's hierarchy has no dimensions.
    """
    assert parallel_desc_symbol.device_tag != device_tag

    retagged_conf = placement_cfg.ParallelConf()
    retagged_conf.set_device_tag(device_tag)
    for name in parallel_desc_symbol.parallel_conf.device_name():
        retagged_conf.add_device_name(name)

    hierarchy_shape = shape_proto_cfg.ShapeProto()
    for extent in parallel_desc_symbol.hierarchy:
        hierarchy_shape.add_dim(extent)
    assert hierarchy_shape.dim_size() > 0
    retagged_conf.mutable_hierarchy().CopyFrom(hierarchy_shape)

    if builder is not None:
        return builder.GetParallelDescSymbol(retagged_conf)
    return oneflow_api.PlacementSymbol(parallel_desc_symbol.symbol_id,
                                       retagged_conf)
예제 #6
0
    def Attr(self, attr_name, attr_value, attr_type_name=None):
        r"""Convert ``attr_value`` to an op attribute and set it on the builder.

        The attribute type is looked up with ``oneflow_api.GetUserOpAttrType``
        from the op type name and ``attr_name``. ``attr_value`` is checked
        against that type, stored in a new ``AttrValue`` and passed to
        ``self._builder.attr``.

        Args:
            attr_name (str): attribute name of op
            attr_value (Any): attribute value of op; scalars for scalar types,
                a tuple/list for shape and list types, members of
                ``oneflow.dtypes()`` for data types
            attr_type_name: deprecated and otherwise ignored; any value other
                than ``None`` only prints a warning followed by the caller's
                stack entry.

        Raises:
            AssertionError: if ``attr_name`` is not a ``str``, the op type name
                is ``None``, or ``attr_value`` does not match the attribute type.
            ValueError: if the looked-up attribute type is not handled here.

        Returns:
            ``self``, so that calls can be chained.
        """
        if attr_type_name is not None:
            print(
                """WARNING: Argument 'attr_type_name' of UserOpConfBuilder.Attr has been deprecated. Please remove it.

            For instance:
                -     .Attr("out_num", out_num, "AttrTypeInt64")
                +     .Attr("out_num", out_num)
                        """
            )
            print(traceback.format_stack()[-2])

        def _append_checked(append, items, allowed):
            # Type-check each element immediately before appending it.
            for item in items:
                assert isinstance(item, allowed)
                append(item)

        sequence_types = (tuple, list)
        number_types = (float, int)
        kinds = user_op_attr_cfg

        attribute = kinds.AttrValue()
        assert isinstance(attr_name, str)
        assert self._op_type_name is not None
        attr_type = oneflow_api.GetUserOpAttrType(self._op_type_name, attr_name)

        if attr_type == kinds.kAtInt32:
            assert isinstance(attr_value, int)
            attribute.set_at_int32(attr_value)
        elif attr_type == kinds.kAtInt64:
            assert isinstance(attr_value, int)
            attribute.set_at_int64(attr_value)
        elif attr_type == kinds.kAtBool:
            assert isinstance(attr_value, bool)
            attribute.set_at_bool(attr_value)
        elif attr_type == kinds.kAtFloat:
            assert isinstance(attr_value, number_types)
            attribute.set_at_float(attr_value)
        elif attr_type == kinds.kAtDouble:
            assert isinstance(attr_value, number_types)
            attribute.set_at_double(attr_value)
        elif attr_type == kinds.kAtString:
            assert isinstance(attr_value, str)
            attribute.set_at_string(attr_value)
        elif attr_type == kinds.kAtShape:
            assert isinstance(attr_value, sequence_types)
            _append_checked(attribute.mutable_at_shape().add_dim, attr_value, int)
        elif attr_type == kinds.kAtDataType:
            assert attr_value in oneflow.dtypes()
            proto_dtype = oneflow_api.deprecated.GetProtoDtype4OfDtype(attr_value)
            assert isinstance(proto_dtype, int)
            attribute.set_at_data_type(data_type_cfg.DataType(proto_dtype))
        elif attr_type == kinds.kAtListInt32:
            assert isinstance(attr_value, sequence_types)
            _append_checked(attribute.mutable_at_list_int32().add_val, attr_value, int)
        elif attr_type == kinds.kAtListInt64:
            assert isinstance(attr_value, sequence_types)
            _append_checked(attribute.mutable_at_list_int64().add_val, attr_value, int)
        elif attr_type == kinds.kAtListFloat:
            assert isinstance(attr_value, sequence_types)
            _append_checked(
                attribute.mutable_at_list_float().add_val, attr_value, number_types
            )
        elif attr_type == kinds.kAtListDataType:
            assert isinstance(attr_value, sequence_types)
            dtype_list = attribute.mutable_at_list_data_type()
            for of_dtype in attr_value:
                assert of_dtype in oneflow.dtypes()
                proto_dtype = oneflow_api.deprecated.GetProtoDtype4OfDtype(of_dtype)
                assert isinstance(proto_dtype, int)
                dtype_list.add_val(data_type_cfg.DataType(proto_dtype))
        elif attr_type == kinds.kAtListShape:
            assert isinstance(attr_value, sequence_types)
            shape_list = attribute.mutable_at_list_shape().mutable_val()
            for dims in attr_value:
                assert isinstance(dims, sequence_types)
                shape = shape_cfg.ShapeProto()
                _append_checked(shape.add_dim, dims, int)
                shape_list.Add().CopyFrom(shape)
        elif attr_type == kinds.kAtListString:
            assert isinstance(attr_value, sequence_types)
            _append_checked(attribute.mutable_at_list_string().add_val, attr_value, str)
        else:
            raise ValueError("Invalid op attribute type {}".format(attr_type))

        self._builder.attr(attr_name, attribute)
        return self
예제 #7
0
def convert_to_user_attr_value(op_type_name, attr_name, attr_value):
    """Build a ``user_op_attr_cfg.AttrValue`` holding ``attr_value``.

    The attribute type is looked up with ``oneflow_api.GetUserOpAttrType`` from
    ``op_type_name`` and ``attr_name``; ``attr_value`` is checked against that
    type and stored in the matching field of a new ``AttrValue``.

    Args:
        op_type_name: op type name used for the attribute-type lookup.
        attr_name (str): name of the attribute.
        attr_value: value to store; scalars for scalar types, a tuple/list for
            shape and list types, members of ``oneflow.dtypes()`` for data types.

    Returns:
        The populated ``AttrValue``.

    Raises:
        AssertionError: if ``attr_name`` is not a ``str`` or ``attr_value``
            does not match the attribute type.
        ValueError: if the looked-up attribute type is not handled here.
    """

    def _append_checked(append, items, allowed):
        # Type-check each element immediately before appending it.
        for item in items:
            assert isinstance(item, allowed)
            append(item)

    sequence_types = (tuple, list)
    number_types = (float, int)
    kinds = user_op_attr_cfg

    attribute = kinds.AttrValue()
    assert isinstance(attr_name, str)
    attr_type = oneflow_api.GetUserOpAttrType(op_type_name, attr_name)

    if attr_type == kinds.kAtInt32:
        assert isinstance(attr_value, int)
        attribute.set_at_int32(attr_value)
    elif attr_type == kinds.kAtInt64:
        assert isinstance(attr_value, int)
        attribute.set_at_int64(attr_value)
    elif attr_type == kinds.kAtBool:
        assert isinstance(attr_value, bool)
        attribute.set_at_bool(attr_value)
    elif attr_type == kinds.kAtFloat:
        assert isinstance(attr_value, number_types)
        attribute.set_at_float(attr_value)
    elif attr_type == kinds.kAtDouble:
        assert isinstance(attr_value, number_types)
        attribute.set_at_double(attr_value)
    elif attr_type == kinds.kAtString:
        assert isinstance(attr_value, str)
        attribute.set_at_string(attr_value)
    elif attr_type == kinds.kAtShape:
        assert isinstance(attr_value, sequence_types)
        _append_checked(attribute.mutable_at_shape().add_dim, attr_value, int)
    elif attr_type == kinds.kAtDataType:
        assert attr_value in oneflow.dtypes()
        proto_dtype = oneflow_api.deprecated.GetProtoDtype4OfDtype(attr_value)
        assert isinstance(proto_dtype, int)
        attribute.set_at_data_type(data_type_cfg.DataType(proto_dtype))
    elif attr_type == kinds.kAtListInt32:
        assert isinstance(attr_value, sequence_types)
        _append_checked(attribute.mutable_at_list_int32().add_val, attr_value, int)
    elif attr_type == kinds.kAtListInt64:
        assert isinstance(attr_value, sequence_types)
        _append_checked(attribute.mutable_at_list_int64().add_val, attr_value, int)
    elif attr_type == kinds.kAtListFloat:
        assert isinstance(attr_value, sequence_types)
        _append_checked(
            attribute.mutable_at_list_float().add_val, attr_value, number_types
        )
    elif attr_type == kinds.kAtListDataType:
        assert isinstance(attr_value, sequence_types)
        dtype_list = attribute.mutable_at_list_data_type()
        for of_dtype in attr_value:
            assert of_dtype in oneflow.dtypes()
            proto_dtype = oneflow_api.deprecated.GetProtoDtype4OfDtype(of_dtype)
            assert isinstance(proto_dtype, int)
            dtype_list.add_val(data_type_cfg.DataType(proto_dtype))
    elif attr_type == kinds.kAtListShape:
        assert isinstance(attr_value, sequence_types)
        shape_list = attribute.mutable_at_list_shape().mutable_val()
        for dims in attr_value:
            assert isinstance(dims, sequence_types)
            shape = shape_cfg.ShapeProto()
            _append_checked(shape.add_dim, dims, int)
            shape_list.Add().CopyFrom(shape)
    elif attr_type == kinds.kAtListString:
        assert isinstance(attr_value, sequence_types)
        _append_checked(attribute.mutable_at_list_string().add_val, attr_value, str)
    else:
        raise ValueError("Invalid op attribute type {}".format(attr_type))
    return attribute