def expand_to(tensors, dim, num):
    """Zero-pad tensor(s) along ``dim`` so that dimension has size ``num``.

    Args:
        tensors: a tensor, or a sequence of tensors (each handled
            independently via ``Tuplize.expand_to``).
        dim: the dimension to pad; negative values index from the end.
        num: the target size of ``dim``.

    Returns:
        The input tensor itself if ``dim`` already has size ``num``; otherwise
        ``torch.cat`` of the input and a block of zeros along ``dim``. For a
        sequence input, a list of such results.

    Raises:
        ValueError: if ``num`` is smaller than the current size of ``dim``, or
            ``tensors`` is neither a tensor nor a sequence.
    """
    if is_tensor(tensors):
        shape = list(tensors.shape)
        current = shape[dim]
        if num == current:
            return tensors
        if num < current:
            raise ValueError(
                'cannot expand dim %s of size %s to smaller size %s'
                % (dim, current, num))
        shape[dim] = num - current
        if hasattr(tensors, 'new_zeros'):
            # PyTorch >= 0.4 merged Tensor and Variable; new_zeros allocates
            # zeros with the same dtype and device as ``tensors``.
            padding_part = tensors.new_zeros(shape)
        else:
            # Legacy (pre-0.4) path: allocate from the raw data type, then
            # re-wrap in a Variable if the input was one.
            data = tensors.data if isinstance(tensors, Variable) else tensors
            padding_part = type(data)(*shape)
            padding_part.fill_(0)
            if isinstance(tensors, Variable):
                padding_part = Variable(padding_part)
        return torch.cat([tensors, padding_part], dim=dim)
    elif is_sequence(tensors):
        return [Tuplize.expand_to(t, dim, num) for t in tensors]
    else:
        raise ValueError
def len(tensors, dim):
    """Return the size of dimension ``dim``.

    For a single tensor its own shape is used; for a sequence of tensors only
    the first element is inspected.

    Raises:
        ValueError: if ``tensors`` is neither a tensor nor a sequence.
    """
    if is_tensor(tensors):
        reference = tensors
    elif is_sequence(tensors):
        reference = tensors[0]
    else:
        raise ValueError
    return reference.shape[dim]
def flatten(tensors, s, e):
    """Merge dimensions ``s`` .. ``e - 1`` into a single inferred (-1) dim.

    Applies to one tensor, or to every tensor of a sequence, via ``view``.

    Raises:
        ValueError: if ``tensors`` is neither a tensor nor a sequence.
    """
    def _merge(t):
        # Keep the leading and trailing dims, let view infer the merged one.
        head, tail = t.shape[:s], t.shape[e:]
        return t.view(head + (-1, ) + tail)

    if is_tensor(tensors):
        return _merge(tensors)
    if is_sequence(tensors):
        return [_merge(t) for t in tensors]
    raise ValueError
def tensor(self, shape):
    """Materialize ``self.v`` as a tensor expanded to ``shape``.

    Level 1: a sequence of scalars is converted with ``tensor_type`` (looked up
    by the Python type of its first element); a non-sequence is used as-is.
    The result is expanded along dim 0 to ``shape[0]`` with ``tur.expand_to``.

    Level 2: if ``self.v`` is a sequence of Python sequences, each row is
    converted with ``tensor_type`` (looked up by the type of ``v[0][0]``).
    The rows are expanded to ``shape[1]``, stacked on dim 0, and the stack is
    expanded along dim 0 to ``shape[0]``. A non-sequence ``self.v`` is
    expanded directly.

    Any other level falls through and returns ``None``.
    """
    level = self._level
    if level == 1:
        flat = self.v
        if is_sequence(flat):
            flat = tensor_type[type(flat[0])](flat)
        return tur.expand_to(flat, 0, shape[0])
    if level != 2:
        # NOTE(review): unsupported nesting depth silently yields None.
        return None
    rows = self.v
    if is_sequence(rows):
        if is_sequence(rows[0]):
            row_type = tensor_type[type(rows[0][0])]
            rows = [row_type(r) for r in rows]
        rows = torch.stack(tur.expand_to(rows, 0, shape[1]), 0)
    return tur.expand_to(rows, 0, shape[0])
def __init__(self, v, level=None):
    """Store ``v`` and record its nesting level in ``self._level``.

    Args:
        v: the wrapped value, possibly nested sequences.
        level: explicit nesting level. When ``None``, the level is inferred
            by repeatedly descending into the first element while the current
            node is a sequence, counting the layers descended.
    """
    self.v = v
    self._level = 0
    if level is not None:
        self._level = level
        return
    depth = 0
    node = v
    while is_sequence(node):
        node = node[0]
        depth += 1
    self._level = depth
def func(name, tensors, *args, **kwargs):
    """Apply the torch operation ``name`` to a tensor or a collection of tensors.

    If ``torch.Tensor`` has an attribute ``name``, it is called as a method on
    the tensor, or on every element of a sequence. Otherwise ``torch.<name>``
    is used:
      * for a tensor or a flat sequence, the input is passed straight through;
      * for a sequence of sequences, the input is transposed with ``zip`` and
        the function is applied to each resulting group.

    Raises:
        ValueError: if ``tensors`` is neither a tensor nor a sequence.
    """
    as_method = hasattr(torch.Tensor, name)
    if is_tensor(tensors):
        if as_method:
            return getattr(tensors, name)(*args, **kwargs)
        return getattr(torch, name)(tensors, *args, **kwargs)
    if not is_sequence(tensors):
        raise ValueError
    if as_method:
        return [getattr(item, name)(*args, **kwargs) for item in tensors]
    if is_sequence(tensors[0]):
        groups = list(zip(*tensors))
        return [getattr(torch, name)(group, *args, **kwargs)
                for group in groups]
    return getattr(torch, name)(tensors, *args, **kwargs)
def __init__(self, white_list=None, black_list=None):
    """Build the tracker from accepted (white) and rejected (black) rules.

    Args:
        white_list: a single rule (a tuple, or a bare non-sequence index such
            as a str or ``None``) or a sequence of rules. A bare index is
            wrapped into a 1-tuple, so ``None`` becomes the rule ``(None,)``.
        black_list: same accepted forms as ``white_list``; ``None`` means no
            rejected rules are registered.
    """
    # A tuple is ONE rule (an index path), not a list of rules, so wrap it.
    if not is_sequence(white_list) or isinstance(white_list, tuple):
        if not isinstance(white_list, tuple):
            white_list = (white_list, )
        white_list = [white_list]
    if not is_sequence(black_list) or isinstance(black_list, tuple):
        if black_list is not None:
            if not isinstance(black_list, tuple):
                black_list = (black_list, )
            black_list = [black_list]
    self.white = {}
    self.black = {}
    for rule in white_list:
        self.add_white_rule(rule)
    if black_list is None:
        return
    # Previously iterated the undefined name ``black_index`` (NameError).
    for rule in black_list:
        self.add_black_rule(rule)
def views(tensors, *args):
    """Reshape tensor(s) with ``view`` using a shape assembled from ``args``.

    Each positional spec contributes to the target shape:
      * a ``slice`` copies that slice of the tensor's current shape;
      * a sequence is spliced in element by element;
      * anything else (e.g. an int) is appended as a single dim.

    A sequence of tensors is handled element-wise via ``Tuplize.views``.

    Raises:
        ValueError: if ``tensors`` is neither a tensor nor a sequence.
    """
    # ``collections.Sequence`` was removed in Python 3.10; use the abc module.
    from collections.abc import Sequence

    if is_tensor(tensors):
        shape_to = []
        for p in args:
            if isinstance(p, slice):
                shape_to += list(tensors.shape[p])
            elif is_sequence(p):
                shape_to += p
            else:
                shape_to += [p]
        return tensors.view(shape_to)
    elif isinstance(tensors, Sequence):
        return [Tuplize.views(t, *args) for t in tensors]
    else:
        raise ValueError
def get_handle(roots, capture, white_list=None, black_list=None):
    """Function to get handle.

    Args:
        roots (list, class): The capture will start at each entry of `roots`.
            If `roots` is a single object with a ``__dict__`` (e.g. a class
            instance), it is automatically converted to [`roots`].
        capture: The type to be captured.
        white_list (optional): A list containing all accepted index paths,
            where an index path is a tuple specifying the level-by-level
            index, e.g., ('name', None, 0) means that handles are produced by
            any element with type `capture` in root['name'][*][0]([*])*,
            where the last bracket group follows regular-expression
            repetition. A str index can also be used to get a class attribute.
        black_list (optional): A list containing all rejected index paths,
            overriding `white_list`.

        Only recursive dict, list, tuple and class containers are accepted,
        with these rules:
            dict: any key or None.
            list: must be None and will always go into the first entry only.
            tuple: any index or None.
            class: any attribute name or None.
        A real index path must not be accepted by more than one white-list
        entry.

    Returns:
        handle(s): A handle is a tuple (root, white_list, black_list), where
        white_list holds the EXACT captured paths, e.g., if dataset.data is a
        list of str, then get_handle(dataset, str, None) produces
        (root, [(data, None), ], None) as a handle. If `roots` contains only
        one element, that single handle is returned instead of a list.

    Raises:
        TypeError: if `roots` is neither a sequence nor an object with a
            ``__dict__``.
    """
    capturer = Capturer(capture, Tracker(white_list, black_list))
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if not (is_sequence(roots) or hasattr(roots, '__dict__')):
        raise TypeError(
            'roots must be a sequence or an object with __dict__, got %s'
            % type(roots).__name__)
    if hasattr(roots, '__dict__'):
        roots = [roots]
    handles = [(root, capturer(root), black_list) for root in roots]
    if len(handles) == 1:
        return handles[0]
    return handles
def denumericalize(self, v):
    """Map an index, or a nested sequence of indices, back through ``self.itos``.

    Sequences are returned as (nested) lists mirroring the input structure.
    """
    if not is_sequence(v):
        return self.itos[v]
    return [self.denumericalize(item) for item in v]
def recursive_add(v):
    """Feed every innermost list of the (possibly nested) list ``v`` to ``c.update``.

    ``c`` is a free variable from the enclosing scope; presumably a
    ``collections.Counter`` of tokens (TODO confirm against the caller).
    Empty lists, at any depth, are skipped. Previously ``v[0]`` raised
    ``IndexError`` on them.
    """
    # assumes v is list-like (truthiness == non-empty)
    if not v:
        return
    if is_sequence(v[0]):
        for x in v:
            recursive_add(x)
    else:
        c.update(v)
def recursive_numericalize(self, v):
    """Replace string tokens in the nested sequence ``v`` in place.

    Each ``str`` element becomes ``self.stoi.get(token, 1)``, so unknown
    tokens map to index 1. Nested sequences are processed recursively.
    Any other element is left untouched.
    """
    for pos, item in enumerate(v):
        if is_sequence(item):
            self.recursive_numericalize(item)
            continue
        if isinstance(item, str):
            v[pos] = self.stoi.get(item, 1)