def _GetInterfaceBlobObject(builder, op_name):
    """Return the blob object associated with the interface op ``op_name``.

    When eager execution is enabled, the blob object is read from the default
    session's ``var_name2var_blob`` mapping. Otherwise a lazy reference blob
    object is created through ``builder.MakeLazyRefBlobObject`` from the op
    attribute and parallel conf that the default session reports for the op.

    Args:
        builder: object providing ``MakeLazyRefBlobObject``; only used when
            eager execution is disabled.
        op_name: name of the interface op.

    Returns:
        The blob object for ``op_name``.

    Raises:
        AssertionError: if the parallel conf has a ``hierarchy`` field that
            contains no dims.
    """
    # A single session lookup serves both the eager and the lazy path.
    sess = session_ctx.GetDefaultSession()
    if oneflow._oneflow_internal.EagerExecutionEnabled():
        return sess.var_name2var_blob[op_name].blob_object

    op_attribute = sess.OpAttribute4InterfaceOpName(op_name)
    cfg_op_attribute = oneflow._oneflow_internal.deprecated.MakeOpAttributeByString(
        str(op_attribute)
    )

    parallel_conf = sess.ParallelConf4LazyInterfaceOpName(op_name)
    if not isinstance(
        parallel_conf,
        oneflow._oneflow_internal.oneflow.core.job.placement.ParallelConf,
    ):
        # The session handed back something other than the cfg ParallelConf
        # class; rebuild it field by field as a cfg object.
        parallel_conf_cfg = placement_cfg.ParallelConf()
        parallel_conf_cfg.set_device_tag(parallel_conf.device_tag)
        for device_name in parallel_conf.device_name:
            parallel_conf_cfg.add_device_name(device_name)
        if parallel_conf.HasField("hierarchy"):
            hierarchy = shape_proto_cfg.ShapeProto()
            for dim in parallel_conf.hierarchy.dim:
                hierarchy.add_dim(dim)
            assert hierarchy.dim_size() > 0
            parallel_conf_cfg.mutable_hierarchy().CopyFrom(hierarchy)
        parallel_conf = parallel_conf_cfg

    return builder.MakeLazyRefBlobObject(op_name, cfg_op_attribute, parallel_conf)
def _inferface_blob_conf_proto_to_cfg(
    inferface_blob_conf_proto, mut_inferface_blob_conf_cfg
):
    """Copy a protobuf ``InterfaceBlobConf`` into a cfg ``InterfaceBlobConf``.

    The shape, data type and ``is_dynamic`` flag are always copied. When the
    source has a ``parallel_distribution`` field, it must hold exactly one sbp
    entry; that entry is copied only if it is a split-parallel one.

    Args:
        inferface_blob_conf_proto: source
            ``interface_blob_conf_proto.InterfaceBlobConf``.
        mut_inferface_blob_conf_cfg: destination
            ``interface_blob_conf_proto_cfg.InterfaceBlobConf``, modified in
            place.

    Raises:
        AssertionError: if either argument has an unexpected type, or the
            source's parallel distribution does not hold exactly one sbp.
    """
    assert isinstance(
        inferface_blob_conf_proto, interface_blob_conf_proto.InterfaceBlobConf
    )
    assert isinstance(
        mut_inferface_blob_conf_cfg, interface_blob_conf_proto_cfg.InterfaceBlobConf
    )
    src = inferface_blob_conf_proto
    dst = mut_inferface_blob_conf_cfg

    # Shape: rebuild dim by dim as a cfg ShapeProto.
    cfg_shape = shape_proto_cfg.ShapeProto()
    for extent in src.shape.dim:
        cfg_shape.add_dim(extent)
    dst.mutable_shape().CopyFrom(cfg_shape)

    # Data type: the proto enum value is converted through int.
    cfg_dtype = dtype_proto_cfg.DataType(int(src.data_type))
    dst.set_data_type(cfg_dtype)

    if src.HasField("parallel_distribution"):
        # TODO(guoran): Process Nd sbp, parallel_distribution_cfg CopyFrom parallel_distribution_proto
        sbp_entries = src.parallel_distribution.sbp_parallel
        assert len(sbp_entries) == 1
        only_sbp = sbp_entries[0]
        if only_sbp.HasField("split_parallel"):
            axis = only_sbp.split_parallel.axis
            cfg_sbp = sbp_parallel_cfg.SbpParallel()
            cfg_sbp.mutable_split_parallel().set_axis(axis)
            cfg_sbp_list = dst.mutable_parallel_distribution().mutable_sbp_parallel()
            cfg_sbp_list.Add().CopyFrom(cfg_sbp)

    dst.set_is_dynamic(src.is_dynamic)
def ReplaceDeviceTag(parallel_desc_symbol, device_tag, builder=None):
    """Return a placement like ``parallel_desc_symbol`` but using ``device_tag``.

    A new cfg ``ParallelConf`` is assembled from the given device tag plus the
    device names and hierarchy dims of ``parallel_desc_symbol``.

    Args:
        parallel_desc_symbol: the placement symbol to copy from; its current
            device tag must differ from ``device_tag``.
        device_tag: the device tag to set on the new parallel conf.
        builder: optional object providing ``GetParallelDescSymbol``. When
            omitted, a ``PlacementSymbol`` is constructed directly with the
            symbol id of ``parallel_desc_symbol``.

    Returns:
        The result of ``builder.GetParallelDescSymbol`` when a builder is
        given, otherwise a new ``PlacementSymbol``.

    Raises:
        AssertionError: if the symbol already uses ``device_tag`` or its
            hierarchy yields no dims.
    """
    assert parallel_desc_symbol.device_tag != device_tag

    new_conf = placement_cfg.ParallelConf()
    new_conf.set_device_tag(device_tag)
    for name in parallel_desc_symbol.parallel_conf.device_name():
        new_conf.add_device_name(name)

    shape = shape_proto_cfg.ShapeProto()
    for extent in parallel_desc_symbol.hierarchy:
        shape.add_dim(extent)
    assert shape.dim_size() > 0
    new_conf.mutable_hierarchy().CopyFrom(shape)

    if builder is not None:
        return builder.GetParallelDescSymbol(new_conf)
    return oneflow._oneflow_internal.PlacementSymbol(
        parallel_desc_symbol.symbol_id, new_conf
    )
def convert_to_user_attr_value(op_type_name, attr_name, attr_value):
    """Build a cfg ``AttrValue`` holding ``attr_value`` for a user op attribute.

    The attribute type is looked up with ``GetUserOpAttrType`` from the op type
    name and attribute name. ``attr_value`` is then checked with assertions and
    written into the field of the ``AttrValue`` that matches that type.

    Args:
        op_type_name: user op type name passed to the attribute type lookup.
        attr_name (str): name of the attribute.
        attr_value: Python value to store; scalars, oneflow dtypes, or
            tuples/lists thereof depending on the attribute type.

    Returns:
        The populated ``user_op_attr_cfg.AttrValue``.

    Raises:
        AssertionError: if ``attr_name`` is not a ``str`` or ``attr_value``
            fails the check for the looked-up attribute type.
        ValueError: if the looked-up attribute type is not handled here.
    """
    result = user_op_attr_cfg.AttrValue()
    assert isinstance(attr_name, str)
    kind = oneflow._oneflow_internal.GetUserOpAttrType(op_type_name, attr_name)

    def _is_sequence(value):
        return isinstance(value, (tuple, list))

    def _append_checked(items, accepted, append):
        # Type-check every element before handing it to the container.
        for item in items:
            assert isinstance(item, accepted)
            append(item)

    def _to_cfg_dtype(of_dtype):
        # oneflow dtype -> proto enum int -> cfg DataType.
        assert of_dtype in oneflow.dtypes()
        proto_dtype = oneflow._oneflow_internal.deprecated.GetProtoDtype4OfDtype(
            of_dtype
        )
        assert isinstance(proto_dtype, int)
        return data_type_cfg.DataType(proto_dtype)

    # Scalar attribute types.
    if kind == user_op_attr_cfg.kAtInt32:
        assert isinstance(attr_value, int)
        result.set_at_int32(attr_value)
        return result
    if kind == user_op_attr_cfg.kAtInt64:
        assert isinstance(attr_value, int)
        result.set_at_int64(attr_value)
        return result
    if kind == user_op_attr_cfg.kAtBool:
        assert isinstance(attr_value, bool)
        result.set_at_bool(attr_value)
        return result
    if kind == user_op_attr_cfg.kAtFloat:
        assert isinstance(attr_value, (float, int))
        result.set_at_float(attr_value)
        return result
    if kind == user_op_attr_cfg.kAtDouble:
        assert isinstance(attr_value, (float, int))
        result.set_at_double(attr_value)
        return result
    if kind == user_op_attr_cfg.kAtString:
        assert isinstance(attr_value, str)
        result.set_at_string(attr_value)
        return result
    if kind == user_op_attr_cfg.kAtShape:
        assert _is_sequence(attr_value)
        shape_field = result.mutable_at_shape()
        _append_checked(attr_value, int, shape_field.add_dim)
        return result
    if kind == user_op_attr_cfg.kAtDataType:
        result.set_at_data_type(_to_cfg_dtype(attr_value))
        return result

    # List attribute types.
    if kind == user_op_attr_cfg.kAtListInt32:
        assert _is_sequence(attr_value)
        int32_list = result.mutable_at_list_int32()
        _append_checked(attr_value, int, int32_list.add_val)
        return result
    if kind == user_op_attr_cfg.kAtListInt64:
        assert _is_sequence(attr_value)
        int64_list = result.mutable_at_list_int64()
        _append_checked(attr_value, int, int64_list.add_val)
        return result
    if kind == user_op_attr_cfg.kAtListFloat:
        assert _is_sequence(attr_value)
        float_list = result.mutable_at_list_float()
        _append_checked(attr_value, (float, int), float_list.add_val)
        return result
    if kind == user_op_attr_cfg.kAtListDataType:
        assert _is_sequence(attr_value)
        dtype_list = result.mutable_at_list_data_type()
        for of_dtype in attr_value:
            dtype_list.add_val(_to_cfg_dtype(of_dtype))
        return result
    if kind == user_op_attr_cfg.kAtListShape:
        assert _is_sequence(attr_value)
        shape_list = result.mutable_at_list_shape().mutable_val()
        for dims in attr_value:
            assert _is_sequence(dims)
            one_shape = shape_cfg.ShapeProto()
            _append_checked(dims, int, one_shape.add_dim)
            shape_list.Add().CopyFrom(one_shape)
        return result
    if kind == user_op_attr_cfg.kAtListString:
        assert _is_sequence(attr_value)
        string_list = result.mutable_at_list_string()
        _append_checked(attr_value, str, string_list.add_val)
        return result

    raise ValueError("Invalid op attribute type {}".format(kind))
def Attr(self, attr_name, attr_value, attr_type_name=None):
    """Set the value of one attribute of the op being built.

    The attribute type is looked up with ``GetUserOpAttrType`` from the op's
    type name and ``attr_name``. ``attr_value`` is checked with assertions,
    stored in a cfg ``AttrValue``, and that value is then parsed from its text
    form into the op conf's ``attr`` entry for ``attr_name``.

    Args:
        attr_name (str): attribute name of the op.
        attr_value (Any): attribute value of the op; scalars, oneflow dtypes,
            or tuples/lists thereof depending on the attribute type.
        attr_type_name: deprecated. When it is not None, a deprecation warning
            and the stack entry of the calling frame are printed; the value
            itself is not used.

    Raises:
        AssertionError: if ``attr_name`` is not a ``str`` or ``attr_value``
            fails the check for the op's attribute type.
        ValueError: if the looked-up attribute type is not handled here.

    Returns:
        This builder (``self``), so calls can be chained.
    """
    if attr_type_name is not None:
        print(
            'WARNING: Argument \'attr_type_name\' of UserOpConfBuilder.Attr has been deprecated. Please remove it.\n\n For instance:\n - .Attr("out_num", out_num, "AttrTypeInt64")\n + .Attr("out_num", out_num)\n '
        )
        # Must be called directly in this frame: [-2] is the caller's entry.
        print(traceback.format_stack()[-2])

    attribute = user_op_attr_cfg.AttrValue()
    assert isinstance(attr_name, str)
    attr_type = oneflow._oneflow_internal.GetUserOpAttrType(
        self.user_op_.op_conf_.user_conf.op_type_name, attr_name
    )

    if attr_type == user_op_attr_cfg.kAtInt32:
        assert isinstance(attr_value, int)
        attribute.set_at_int32(attr_value)
    elif attr_type == user_op_attr_cfg.kAtInt64:
        assert isinstance(attr_value, int)
        attribute.set_at_int64(attr_value)
    elif attr_type == user_op_attr_cfg.kAtBool:
        assert isinstance(attr_value, bool)
        attribute.set_at_bool(attr_value)
    elif attr_type == user_op_attr_cfg.kAtFloat:
        assert isinstance(attr_value, (float, int))
        attribute.set_at_float(attr_value)
    elif attr_type == user_op_attr_cfg.kAtDouble:
        assert isinstance(attr_value, (float, int))
        attribute.set_at_double(attr_value)
    elif attr_type == user_op_attr_cfg.kAtString:
        assert isinstance(attr_value, str)
        attribute.set_at_string(attr_value)
    elif attr_type == user_op_attr_cfg.kAtShape:
        assert isinstance(attr_value, (tuple, list))
        attribute_mutable_at_shape = attribute.mutable_at_shape()
        for x in attr_value:
            assert isinstance(x, int)
            attribute_mutable_at_shape.add_dim(x)
    elif attr_type == user_op_attr_cfg.kAtDataType:
        # oneflow dtype -> proto enum int -> cfg DataType.
        assert attr_value in flow.dtypes()
        attr_value = oneflow._oneflow_internal.deprecated.GetProtoDtype4OfDtype(
            attr_value
        )
        assert isinstance(attr_value, int)
        attribute.set_at_data_type(data_type_cfg.DataType(attr_value))
    elif attr_type == user_op_attr_cfg.kAtListInt32:
        assert isinstance(attr_value, (tuple, list))
        attribute_mutable_at_list_int32 = attribute.mutable_at_list_int32()
        for x in attr_value:
            assert isinstance(x, int)
            attribute_mutable_at_list_int32.add_val(x)
    elif attr_type == user_op_attr_cfg.kAtListInt64:
        assert isinstance(attr_value, (tuple, list))
        attribute_mutable_at_list_int64 = attribute.mutable_at_list_int64()
        for x in attr_value:
            assert isinstance(x, int)
            attribute_mutable_at_list_int64.add_val(x)
    elif attr_type == user_op_attr_cfg.kAtListFloat:
        assert isinstance(attr_value, (tuple, list))
        attribute_mutable_at_list_float = attribute.mutable_at_list_float()
        for x in attr_value:
            assert isinstance(x, (float, int))
            attribute_mutable_at_list_float.add_val(x)
    elif attr_type == user_op_attr_cfg.kAtListDataType:
        assert isinstance(attr_value, (tuple, list))
        attribute_mutable_at_list_data_type = attribute.mutable_at_list_data_type()
        for x in attr_value:
            assert x in flow.dtypes()
            x = oneflow._oneflow_internal.deprecated.GetProtoDtype4OfDtype(x)
            assert isinstance(x, int)
            attribute_mutable_at_list_data_type.add_val(data_type_cfg.DataType(x))
    elif attr_type == user_op_attr_cfg.kAtListShape:
        assert isinstance(attr_value, (tuple, list))
        attribute_mutable_at_list_shape = (
            attribute.mutable_at_list_shape().mutable_val()
        )
        for x in attr_value:
            assert isinstance(x, (tuple, list))
            shape = shape_cfg.ShapeProto()
            for dim in x:
                assert isinstance(dim, int)
                shape.add_dim(dim)
            attribute_mutable_at_list_shape.Add().CopyFrom(shape)
    elif attr_type == user_op_attr_cfg.kAtListString:
        assert isinstance(attr_value, (tuple, list))
        attribute_mutable_at_list_string = attribute.mutable_at_list_string()
        for x in attr_value:
            assert isinstance(x, str)
            attribute_mutable_at_list_string.add_val(x)
    else:
        raise ValueError("Invalid op attribute type {}".format(attr_type))

    # The cfg AttrValue is converted to the protobuf AttrValue through its
    # text representation before being copied into the op conf.
    self.user_op_.op_conf_.user_conf.attr[attr_name].CopyFrom(
        text_format.Parse(str(attribute), attr_value_pb.AttrValue())
    )
    return self