示例#1
0
def convert_str_arr_to_int(array: Union[List, np.ndarray], label: ClassLabel):
    """Replace class-name strings in ``array`` with their integer class ids.

    Args:
        array: A list or numpy array whose elements may be class-name strings.
            Elements that are not strings are left untouched.
        label: The ClassLabel used to map names to ids via ``label.str2int``.

    Returns:
        For a numpy array of unicode strings (any shape), a new integer array
        of the same shape. Its dtype is ``int8`` when every id fits in int8,
        otherwise ``int64``. For any other input, ``array`` itself, with its
        string elements replaced in place.

    Raises:
        ClassLabelValueError: If a string is rejected by ``label.str2int``
            (it raises ``KeyError``).
    """

    def to_int(name):
        try:
            return label.str2int(name)
        except KeyError as err:
            raise ClassLabelValueError(label.names, name) from err

    if isinstance(array, np.ndarray) and array.dtype.type is np.str_:
        # Build a new array instead of writing ids back into ``array``: a
        # fixed-width unicode array would store them as strings and truncate
        # multi-digit ids (12 written into a '<U1' array becomes '1').
        ids = [to_int(elem) for elem in array.ravel()]
        # Keep the historical int8 dtype when possible, but never overflow it.
        int8_info = np.iinfo(np.int8)
        fits_int8 = all(int8_info.min <= i <= int8_info.max for i in ids)
        dtype = "int8" if fits_int8 else "int64"
        return np.asarray(ids, dtype=dtype).reshape(array.shape)

    for i, elem in enumerate(array):
        if isinstance(elem, str):
            array[i] = to_int(elem)
    return array
示例#2
0
def check_class_label(value: Union[np.ndarray, list], label: ClassLabel):
    """Check if value can be assigned to predefined ClassLabel.

    Class-name strings are converted to integer ids with ``label.str2int``.
    List inputs (and non-string numpy arrays) are updated in place.

    Args:
        value: A single class label (an int or a class-name string), or a 1-D
            list / numpy array of such labels.
        label: The ClassLabel that defines the valid classes.

    Returns:
        The single converted label when exactly one label was given. An empty
        input is returned unchanged. Otherwise, the converted sequence; this is
        a new int64 array when ``value`` is a numpy array of strings.

    Raises:
        ClassLabelValueError: If a string is rejected by ``label.str2int``
            (it raises ``KeyError``), or a label lies outside
            ``[0, label.num_classes - 1]``.
    """
    is_str_array = isinstance(value, np.ndarray) and value.dtype.type is np.str_
    if not isinstance(value, Iterable) or isinstance(value, str):
        assign_class_labels = [value]
    elif is_str_array:
        # Work on a list copy: ids written back into a fixed-width unicode
        # array would be stored as (possibly truncated) strings and then fail
        # the numeric range check below.
        assign_class_labels = value.tolist()
    else:
        assign_class_labels = value

    for i, assign_class_label in enumerate(assign_class_labels):
        if isinstance(assign_class_label, str):
            try:
                assign_class_labels[i] = label.str2int(assign_class_label)
            except KeyError as err:
                raise ClassLabelValueError(label.names, assign_class_label) from err

    # Report the offending label itself (not whichever element happened to be
    # last) together with the full set of valid ids 0 .. num_classes - 1.
    for assign_class_label in assign_class_labels:
        if assign_class_label < 0 or assign_class_label > label.num_classes - 1:
            raise ClassLabelValueError(range(label.num_classes), assign_class_label)

    if is_str_array:
        assign_class_labels = np.asarray(assign_class_labels, dtype=np.int64)
    if len(assign_class_labels) == 1:
        return assign_class_labels[0]
    return assign_class_labels