def recursively_map(an_object, function, is_key=False):
    """Rebuild *an_object* bottom-up, passing every rebuilt node through *function*.

    Children are mapped before their parent. Dicts are rebuilt as dicts, with
    their keys mapped using ``is_key=True``. Every other iterable becomes a
    tuple so the result is hashable. Strings are leaves. Any object with a
    ``tolist`` method (e.g. numpy/pytorch scalars) is replaced by
    ``tolist()``; an iterable result is turned into a tuple, but its elements
    are not recursed into.

    Args:
        an_object: the value to map.
        function: called as ``function(value, is_key=...)`` on each rebuilt
            node; its return value replaces that node.
        is_key: whether *an_object* is being used as a dict key.

    Returns:
        ``function`` applied to the rebuilt top-level value.
    """
    from tools.basics import is_iterable

    # base case 1 (iterable but treated like a primitive).
    # This branch must be exclusive of the ones below: a string is iterable
    # and each of its characters is again a string, so letting it fall
    # through into the iterable branch never terminates.
    if isinstance(an_object, str):
        return_value = an_object
    # base case 2 (exists because of scalar numpy/pytorch/tensorflow objects)
    elif hasattr(an_object, "tolist"):
        return_value = an_object.tolist()
    # base case 3
    elif not is_iterable(an_object):
        return_value = an_object
    elif isinstance(an_object, dict):
        return_value = {
            recursively_map(each_key, function, is_key=True): recursively_map(each_value, function)
            for each_key, each_value in an_object.items()
        }
    else:
        return_value = [recursively_map(each, function) for each in an_object]

    # convert lists to tuples so they are hashable
    if is_iterable(return_value) and not isinstance(return_value, (dict, str)):
        return_value = tuple(return_value)
    return function(return_value, is_key=is_key)
def merge(old_value, new_value):
    """Recursively merge *new_value* on top of *old_value*.

    Mappings are merged key by key. Other iterables, except ``str`` and
    ``bytes``, are treated as ``{index: item}`` mappings. Any other value in
    *new_value* simply replaces what was there.

    Args:
        old_value: the existing value. If it is already a mapping it is
            updated in place. A non-mapping iterable is first copied into an
            ``{index: item}`` dict. Anything else (including strings) is
            replaced by a fresh ``{}``.
        new_value: the value whose entries win.

    Returns:
        *new_value* itself when it is not mapping-like; otherwise the updated
        mapping built from *old_value*.
    """

    def as_mapping(value):
        # Return *value* as a mapping, or None when it must stay atomic.
        # str/bytes are iterable, but enumerating a string yields one-char
        # strings that are iterable again, which used to recurse forever.
        if isinstance(value, collections.abc.Mapping):
            return value
        if is_iterable(value) and not isinstance(value, (str, bytes)):
            return dict(enumerate(value))
        return None

    new_mapping = as_mapping(new_value)
    # non-container values just replace the current value
    if new_mapping is None:
        return new_value

    old_mapping = as_mapping(old_value)
    if old_mapping is None:
        # force the old value to be a mapping so keys can be merged into it
        old_mapping = {}

    # override each key recursively
    for key, updated_value in new_mapping.items():
        old_mapping[key] = merge(old_mapping.get(key, {}), updated_value)
    return old_mapping
def to_tensor(an_object): from tools.basics import is_iterable # if already a tensor, just return if isinstance(an_object, torch.Tensor): return an_object # if scalar, wrap it with a tensor if not is_iterable(an_object): return torch.tensor(an_object) else: as_list = tuple([ each for each in an_object ]) # # check for all-scalar container # is_all_scalar = True # for each in as_list: # if is_iterable(each): # is_all_scalar = False # break # if is_all_scalar: # return torch.tensor(as_list) size_mismatch = False biggest_number_of_dimensions = 0 non_one_dimensions = None converted_data = [] # check the shapes of everything for each in as_list: tensor = to_tensor(each) converted_data.append(tensor) skipping = True each_non_one_dimensions = [] for index, each_dimension in enumerate(tensor.shape): # keep track of number of dimensions if index+1 > biggest_number_of_dimensions: biggest_number_of_dimensions += 1 if each_dimension != 1: skipping = False if skipping and each_dimension == 1: continue else: each_non_one_dimensions.append(each_dimension) # if uninitilized if non_one_dimensions is None: non_one_dimensions = list(each_non_one_dimensions) # if dimension already exists else: # make sure its the correct shape if non_one_dimensions != each_non_one_dimensions: size_mismatch = True break if size_mismatch: sizes = "\n".join([ f" {tuple(to_tensor(each).shape)}" for each in as_list]) raise Exception(f'When converting an object to a torch tensor, there was an issue with the shapes not being uniform. 
All shapes need to be the same, but instead the shapes were:\n {sizes}') # make all the sizes the same by filling in the dimensions with a size of one reshaped_list = [] for each in converted_data: shape = tuple(each.shape) number_of_dimensions = len(shape) number_of_missing_dimensions = biggest_number_of_dimensions - number_of_dimensions missing_dimensions_tuple = (1,)*number_of_missing_dimensions reshaped_list.append(torch.reshape(each, (*missing_dimensions_tuple, *shape))) return torch.stack(reshaped_list).type(torch.float)
# Memo marker for a non-dict container whose conversion has started but not finished.
_TO_PURE_IN_PROGRESS = object()


def to_pure(an_object, recursion_help=None):
    """Convert *an_object* into plain Python data built from tuples and dicts.

    Conversion rules:
    * ``str`` values and non-iterables are returned as-is.
    * Objects with a ``tolist`` method (e.g. numpy/pytorch/tensorflow arrays
      and scalars) are replaced by ``tolist()``. That result is converted
      too, so nested lists become nested tuples.
    * ``dict`` instances (including subclasses) become plain ``dict`` objects
      with converted keys and values.
    * Every other iterable becomes a ``tuple`` of converted elements, so it
      is hashable when its elements are.

    An object reached more than once is converted once, and the same result
    object is reused.

    Cycles: a reference cycle that re-enters a ``dict`` is preserved (the
    output dict refers back to itself the same way). A cycle that re-enters
    any other container cannot be expressed, because a tuple cannot contain
    itself, so it raises ``ValueError``.

    Args:
        an_object: the value to convert.
        recursion_help: internal memo shared between recursive calls, keyed
            by ``id()``; callers should leave it as ``None``.

    Raises:
        ValueError: if a non-dict container (directly or indirectly)
            contains itself.
    """
    from tools.basics import is_iterable

    if recursion_help is None:
        recursion_help = {}

    # Each memo entry is (source_object, result). Storing the source keeps it
    # alive: id() values are only unique among live objects, so without this
    # a generator yielding fresh temporaries could reuse an id and be handed
    # another object's cached result.
    object_id = id(an_object)
    if object_id in recursion_help:
        _, cached = recursion_help[object_id]
        if cached is _TO_PURE_IN_PROGRESS:
            raise ValueError(
                "to_pure() cannot convert a self-referential object: a non-dict "
                "container contains itself, and a tuple cannot contain itself"
            )
        return cached

    # base case 1 (iterable but treated like a primitive)
    if isinstance(an_object, str):
        result = an_object
    # base case 2 (exists because of scalar/array numpy/pytorch/tensorflow
    # objects); tolist() may return nested lists, so convert them as well
    elif hasattr(an_object, "tolist"):
        result = to_pure(an_object.tolist(), recursion_help)
    # base case 3
    elif not is_iterable(an_object):
        result = an_object
    elif isinstance(an_object, dict):
        # Register the output dict before recursing, so a reference back to
        # this dict resolves to the (still-filling) output dict.
        result = {}
        recursion_help[object_id] = (an_object, result)
        for each_key, each_value in an_object.items():
            # convert the key before the value (dict-comprehension order)
            pure_key = to_pure(each_key, recursion_help)
            result[pure_key] = to_pure(each_value, recursion_help)
    else:
        # Tuples are immutable and cannot be pre-registered the way dicts
        # are, so mark this container as in progress to detect cycles.
        recursion_help[object_id] = (an_object, _TO_PURE_IN_PROGRESS)
        result = tuple([to_pure(each, recursion_help) for each in an_object])

    # update the cache with the real value
    recursion_help[object_id] = (an_object, result)
    return result
def is_like_generator(thing):
    """Tell whether *thing* can be iterated like a sequence of items.

    Text-like values (``str`` and ``bytes``) are excluded even though they
    are iterable.

    Returns:
        The falsy result of ``is_iterable(thing)`` when it is not iterable;
        otherwise True unless *thing* is a ``str`` or ``bytes``.
    """
    iterable = is_iterable(thing)
    if not iterable:
        return iterable
    return not isinstance(thing, (str, bytes))