Ejemplo n.º 1
0
def _process_list_for_saving(l: list) -> list:
    """
    The yaml.dump function can't save Tensors, ndarrays, or callables, so we cast them to types it can save.

    A new container is built instead of deep-copying the whole input up front. The input is still never mutated, but
    values that get replaced anyway (Tensors, arrays, modules, callables) are not deep-copied. Deep-copying them is
    wasteful, and it raises for non-leaf Tensors that are part of an autograd graph. Items that need no conversion are
    deep-copied individually, so the result does not alias the input.

    :param l: list containing parameters to save; tuples are accepted as well
    :return: list with values processable by yaml.dump, or a tuple if `l` is a tuple
    """
    processed = []
    for item in l:
        if isinstance(item, (to.Tensor, np.ndarray)):
            # Save Tensors and arrays as (nested) lists
            value = item.tolist()
        elif isinstance(item, np.floating):
            # PyYAML can not save numpy floats
            value = float(item)
        elif isinstance(item, (np.integer, np.bool_)) and not isinstance(item, np.timedelta64):
            # Same for numpy ints and bools (timedelta64 subclasses np.integer, but converting it would drop its unit)
            value = item.item()
        elif isinstance(item, nn.Module):
            # Only save the class name as a string (checked before `callable` because modules are callable)
            value = get_class_name(item)
        elif callable(item):
            # Save the string representation of the callable, falling back to its name
            try:
                value = str(item)
            except AttributeError:
                value = item.__name__
        elif isinstance(item, dict):
            # If the value is another dict, recursively go through this one
            value = _process_dict_for_saving(item)
        elif isinstance(item, (list, tuple)):
            # If the value is a list or tuple, recursively go through this one
            value = _process_list_for_saving(item)
        elif item is None:
            value = "None"
        else:
            # Keep all other values, but copy them so the result does not alias the input
            value = deepcopy(item)
        processed.append(value)
    # The returned object should be of the same type as the input
    return processed if isinstance(l, list) else tuple(processed)
Ejemplo n.º 2
0
def _process_dict_for_saving(d: dict) -> dict:
    """
    The yaml.dump function can't save Tensors, ndarrays, or callables, so we cast them to types it can save.

    A new dict is built instead of deep-copying the whole input up front. The input is still never mutated, but values
    that get replaced anyway (Tensors, arrays, modules, callables) are not deep-copied. Deep-copying them is wasteful,
    and it raises for non-leaf Tensors that are part of an autograd graph. Values that need no conversion are
    deep-copied individually, so the result does not alias the input.

    :param d: dict containing parameters to save
    :return: dict with values processable by yaml.dump
    """
    processed = {}
    for key, value in d.items():
        if isinstance(value, (to.Tensor, np.ndarray)):
            # Save Tensors and arrays as (nested) lists
            processed[key] = value.tolist()
        elif isinstance(value, np.floating):
            # PyYAML can not save numpy floats
            processed[key] = float(value)
        elif isinstance(value, (np.integer, np.bool_)) and not isinstance(value, np.timedelta64):
            # Same for numpy ints and bools (timedelta64 subclasses np.integer, but converting it would drop its unit)
            processed[key] = value.item()
        elif isinstance(value, nn.Module):
            # Only save the class name as a string (checked before `callable` because modules are callable)
            processed[key] = get_class_name(value)
        elif callable(value):
            # Save the string representation of the callable, with two fallbacks
            try:
                processed[key] = str(value)
            except AttributeError:
                try:
                    processed[key] = get_class_name(value)
                except Exception:
                    processed[key] = value.__name__
        elif isinstance(value, dict):
            # If the value is another dict, recursively go through this one
            processed[key] = _process_dict_for_saving(value)
        elif isinstance(value, (list, tuple)):
            # If the value is a list or tuple, recursively go through this one
            processed[key] = _process_list_for_saving(value)
        elif value is None:
            processed[key] = "None"
        else:
            # Keep all other values, but copy them so the result does not alias the input
            processed[key] = deepcopy(value)
    return processed
Ejemplo n.º 3
0
def _process_list_for_saving(l: [list, tuple]) -> [list, tuple]:
    """
    The yaml.dump function can't save PyTorch tensors, numpy arrays, or callables, so we cast them to types it can save.

    A new container is built instead of deep-copying the whole input up front. The input is still never mutated, but
    values that get replaced anyway (tensors, arrays, modules, callables) are not deep-copied. Deep-copying them is
    wasteful, and it raises for non-leaf tensors that are part of an autograd graph. Items that need no conversion are
    deep-copied individually, so the result does not alias the input.

    :param l: list or tuple containing parameters to save
    :return: list or tuple (same type as the input) with values processable by yaml.dump
    """
    processed = []
    for item in l:
        if isinstance(item, (to.Tensor, np.ndarray)):
            # Save tensors and arrays as (nested) lists
            value = item.tolist()
        elif isinstance(item, np.floating):
            # PyYAML can not save numpy floats
            value = float(item)
        elif isinstance(item, (np.integer, np.bool_)) and not isinstance(item, np.timedelta64):
            # Same for numpy ints and bools (timedelta64 subclasses np.integer, but converting it would drop its unit)
            value = item.item()
        elif isinstance(item, nn.Module):
            # Only save the class name as a string (checked before `callable` because modules are callable)
            value = get_class_name(item)
        elif callable(item):
            # Save the string representation of the callable, falling back to its name
            try:
                value = str(item)
            except AttributeError:
                value = item.__name__
        elif isinstance(item, dict):
            # If the value is another dict, recursively go through this one
            value = _process_dict_for_saving(item)
        elif isinstance(item, (list, tuple)):
            # If the value is a list or tuple, recursively go through this one
            value = _process_list_for_saving(item)
        elif item is None:
            value = "None"
        else:
            # Keep all other values, but copy them so the result does not alias the input
            value = deepcopy(item)
        processed.append(value)
    # The returned object should be of the same type as the input
    return processed if isinstance(l, list) else tuple(processed)