def scatter_map(obj):
    """Recursively split ``obj`` into per-target pieces.

    ``target_gpus`` and ``dim`` are free variables resolved outside this
    function. Tensors are handed to ``Scatter.apply``. Non-empty tuples,
    lists, dicts and ``DataStreams`` are scattered element by element, and
    the per-element results are transposed so that each output entry groups
    the matching piece of every element. Any other object (including empty
    containers) is repeated once per entry of ``target_gpus``.
    """
    if isinstance(obj, torch.Tensor):
        return Scatter.apply(target_gpus, None, dim, obj)

    if isinstance(obj, tuple) and len(obj) > 0:
        scattered = [scatter_map(element) for element in obj]
        return list(zip(*scattered))

    if isinstance(obj, list) and len(obj) > 0:
        scattered = [scatter_map(element) for element in obj]
        return [list(group) for group in zip(*scattered)]

    if isinstance(obj, dict) and len(obj) > 0:
        # Each (key, value) pair is scattered as a tuple, then the mapping
        # type is rebuilt from the regrouped pairs.
        rebuild = type(obj)
        scattered = [scatter_map(pair) for pair in obj.items()]
        return [rebuild(group) for group in zip(*scattered)]

    if isinstance(obj, DataStreams) and len(obj) > 0:
        # Same treatment as plain dicts, applied to DataStreams objects.
        rebuild = type(obj)
        scattered = [scatter_map(pair) for pair in obj.items()]
        return [rebuild(group) for group in zip(*scattered)]

    # Return "unscattered" object for all GPUs.
    # This seems to be the cause of the issue for SentenceEmbeddings!
    # TODO: further investigate.
    return [obj for _ in target_gpus]
def scatter_map(obj):
    """Recursively split ``obj`` into per-target pieces.

    ``target_gpus``, ``chunk_sizes`` and ``dim`` are free variables resolved
    outside this function. Tensors are handed to ``Scatter.apply``. Non-empty
    tuples, lists and dicts are scattered element by element, and the
    per-element results are transposed so that each output entry groups the
    matching piece of every element. Any other object is repeated once per
    entry of ``target_gpus``.

    Raises:
        Exception: whatever ``Scatter.apply`` raises is propagated unchanged
            after the failing arguments have been printed.
    """
    if isinstance(obj, torch.Tensor):
        try:
            return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
        except Exception:
            # Print the arguments that made the scatter fail, then re-raise
            # so the caller still sees the original exception and traceback
            # instead of the interpreter silently exiting.
            print('obj', obj.size())
            print('dim', dim)
            print('chunk_sizes', chunk_sizes)
            raise
    if isinstance(obj, tuple) and len(obj) > 0:
        return list(zip(*map(scatter_map, obj)))
    if isinstance(obj, list) and len(obj) > 0:
        return list(map(list, zip(*map(scatter_map, obj))))
    if isinstance(obj, dict) and len(obj) > 0:
        # Each (key, value) pair is scattered as a tuple, then the mapping
        # type is rebuilt from the regrouped pairs.
        return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
    # Nothing to split: hand the same object to every target.
    return [obj for _ in target_gpus]
def scatter_map(obj):
    """Recursively split ``obj`` into per-target pieces.

    ``target_gpus`` and ``dim`` are free variables resolved outside this
    function. Tensors are handed to ``Scatter.apply``; ``ScatterableList``
    objects are scattered through their pointer-tensor representation.
    Non-empty tuples, lists and dicts are scattered element by element, and
    the per-element results are transposed so that each output entry groups
    the matching piece of every element. Any other object is repeated once
    per entry of ``target_gpus``.
    """
    if isinstance(obj, torch.Tensor):
        return Scatter.apply(target_gpus, None, dim, obj)

    if isinstance(obj, ScatterableList):
        # Scatter a tensor of pointers so the split follows exactly the same
        # path as an ordinary tensor, then rebuild one list per chunk.
        pointer_chunks = scatter_map(obj.to_pointer_tensor())
        return list(map(obj.from_pointer_tensor, pointer_chunks))

    if isinstance(obj, tuple) and obj:
        scattered = [scatter_map(element) for element in obj]
        return list(zip(*scattered))

    if isinstance(obj, list) and obj:
        scattered = [scatter_map(element) for element in obj]
        return [list(group) for group in zip(*scattered)]

    if isinstance(obj, dict) and obj:
        # Each (key, value) pair is scattered as a tuple, then the mapping
        # type is rebuilt from the regrouped pairs.
        rebuild = type(obj)
        scattered = [scatter_map(pair) for pair in obj.items()]
        return [rebuild(group) for group in zip(*scattered)]

    # Nothing to split: hand the same object to every target.
    return [obj for _ in target_gpus]
def scatter_map(obj, chunk_sizes=None):
    """Recursively split ``obj`` into per-target pieces.

    ``target_gpus`` and ``dim`` are free variables resolved outside this
    function.

    Args:
        obj: Object to scatter. Tensors go to ``Scatter.apply``; non-empty
            tuples, lists and dicts are scattered element by element.
        chunk_sizes: Value forwarded to ``Scatter.apply`` for every tensor
            reached from ``obj``. A dict may override it for its own
            contents through a ``"chunk_sizes"`` entry.

    Returns:
        For tensors, the result of ``Scatter.apply``. For containers, a list
        in which each entry groups the matching piece of every element. For
        anything else, ``obj`` repeated once per entry of ``target_gpus``.
    """
    if isinstance(obj, torch.Tensor):
        return Scatter.apply(target_gpus, chunk_sizes, dim, obj)  # type:ignore

    if isinstance(obj, tuple) and obj:
        scattered = [
            scatter_map(element, chunk_sizes=chunk_sizes) for element in obj
        ]
        return list(zip(*scattered))

    if isinstance(obj, list) and obj:
        scattered = [
            scatter_map(element, chunk_sizes=chunk_sizes) for element in obj
        ]
        return [list(group) for group in zip(*scattered)]

    if isinstance(obj, dict) and obj:
        # A dict carrying its own "chunk_sizes" entry replaces the inherited
        # value for everything scattered beneath it.
        chunk_sizes = obj.get("chunk_sizes", chunk_sizes)
        rebuild = type(obj)
        scattered = [
            scatter_map(pair, chunk_sizes=chunk_sizes) for pair in obj.items()
        ]
        return [rebuild(group) for group in zip(*scattered)]

    # Nothing to split: hand the same object to every target.
    return [obj for _ in target_gpus]
def scatter_map(obj):
    """Recursively split ``obj`` into per-target pieces.

    ``target_gpus`` and ``dim`` are free variables resolved outside this
    function. Tensors are handed to ``OrigScatter.apply``. A
    ``DataContainer`` either yields its ``data`` untouched (when
    ``cpu_only`` is set) or has its ``data`` passed to ``Scatter.forward``.
    Non-empty tuples, lists and dicts are scattered element by element, and
    the per-element results are transposed so that each output entry groups
    the matching piece of every element. Any other object is repeated once
    per entry of ``target_gpus``.
    """
    if isinstance(obj, torch.Tensor):
        return OrigScatter.apply(target_gpus, None, dim, obj)

    if isinstance(obj, DataContainer):
        # CPU-only containers bypass the scatter and return their payload
        # as it is.
        if obj.cpu_only:
            return obj.data
        return Scatter.forward(target_gpus, obj.data)

    if isinstance(obj, tuple) and len(obj) > 0:
        scattered = [scatter_map(element) for element in obj]
        return list(zip(*scattered))

    if isinstance(obj, list) and len(obj) > 0:
        scattered = [scatter_map(element) for element in obj]
        return [list(group) for group in zip(*scattered)]

    if isinstance(obj, dict) and len(obj) > 0:
        # Each (key, value) pair is scattered as a tuple, then the mapping
        # type is rebuilt from the regrouped pairs.
        rebuild = type(obj)
        scattered = [scatter_map(pair) for pair in obj.items()]
        return [rebuild(group) for group in zip(*scattered)]

    # Nothing to split: hand the same object to every target.
    return [obj for _ in target_gpus]
def scatter_map(obj):
    """Recursively split ``obj`` into per-target pieces.

    ``target_gpus`` and ``dim`` are free variables resolved outside this
    function. Tensors are handed to ``Scatter.apply`` and non-empty lists to
    ``ScatterShallow.apply``. Non-empty tuples and dicts are scattered
    element by element, and the per-element results are transposed so that
    each output entry groups the matching piece of every element. Any other
    object is repeated once per entry of ``target_gpus``.
    """
    if isinstance(obj, torch.Tensor):
        return Scatter.apply(target_gpus, None, dim, obj)

    if isinstance(obj, list) and len(obj) > 0:
        return ScatterShallow.apply(target_gpus, dim, obj)

    # `inputs` is either a tuple for positional arguments or a dict for
    # keyword arguments, so just recursively go deeper.
    if isinstance(obj, tuple) and len(obj) > 0:
        scattered = [scatter_map(element) for element in obj]
        return list(zip(*scattered))

    if isinstance(obj, dict) and len(obj) > 0:
        scattered_pairs = [scatter_map(pair) for pair in obj.items()]
        regrouped = list(zip(*scattered_pairs))
        rebuild = type(obj)
        return [rebuild(group) for group in regrouped]

    # Nothing to split: hand the same object to every target.
    return [obj for _ in target_gpus]
def recursive_apply(target_gpus, dim, input):
    """Apply ``Scatter`` to every tensor nested inside ``input``.

    Args:
        target_gpus: Forwarded to ``Scatter.apply`` as the scatter targets.
        dim: Forwarded to ``Scatter.apply`` as the dimension argument.
        input: A tensor, a list/tuple/dict (possibly nested) or any other
            object.

    Returns:
        For a tensor, the first element of the ``Scatter.apply`` result.
        For a list, tuple or dict, a container of the same kind whose
        elements (dict values) have been processed recursively. Any other
        object is returned unchanged.
    """
    if isinstance(input, torch.Tensor):
        # Only the first piece produced by the scatter is kept.
        return Scatter.apply(target_gpus, None, dim, input)[0]
    if isinstance(input, list):
        return [
            ScatterShallow.recursive_apply(target_gpus, dim, i) for i in input
        ]
    if isinstance(input, tuple):
        # Materialise a real tuple: a bare parenthesised comprehension would
        # return a one-shot generator instead of the same container kind.
        return tuple(
            ScatterShallow.recursive_apply(target_gpus, dim, i) for i in input
        )
    if isinstance(input, dict):
        return {
            k: ScatterShallow.recursive_apply(target_gpus, dim, v)
            for k, v in input.items()
        }
    return input
def scatter_map(obj):
    """Recursively split ``obj`` into per-target pieces.

    ``target_gpus`` and ``dim`` are free variables resolved outside this
    function. ``Variable`` objects are handed to ``Scatter.apply``; any other
    tensor trips the assertion. A ``ScatterList`` is cut into
    ``len(target_gpus)`` consecutive slices of ``ceil(len(obj) /
    len(target_gpus))`` items each (trailing slices may be shorter or
    empty). Non-empty tuples, lists and dicts are scattered element by
    element, and the per-element results are transposed so that each output
    entry groups the matching piece of every element. Any other object is
    repeated once per entry of ``target_gpus``.
    """
    if isinstance(obj, Variable):
        return Scatter.apply(target_gpus, None, dim, obj)

    assert not torch.is_tensor(obj), "Tensors not supported in scatter."

    if isinstance(obj, ScatterList):
        num_targets = len(target_gpus)
        chunk_size = int(ceil(float(len(obj)) / float(num_targets)))
        slices = []
        for index in range(num_targets):
            start = index * chunk_size
            stop = (index + 1) * chunk_size
            slices.append(obj[start:stop])
        return slices

    if isinstance(obj, tuple) and len(obj) > 0:
        scattered = [scatter_map(element) for element in obj]
        return list(zip(*scattered))

    if isinstance(obj, list) and len(obj) > 0:
        scattered = [scatter_map(element) for element in obj]
        return [list(group) for group in zip(*scattered)]

    if isinstance(obj, dict) and len(obj) > 0:
        # Each (key, value) pair is scattered as a tuple, then the mapping
        # type is rebuilt from the regrouped pairs.
        rebuild = type(obj)
        scattered = [scatter_map(pair) for pair in obj.items()]
        return [rebuild(group) for group in zip(*scattered)]

    # Nothing to split: hand the same object to every target.
    return [obj for _ in target_gpus]
def scatter_map(obj):
    """Recursively split ``obj`` into per-target pieces.

    ``target_gpus`` and ``dim`` are free variables resolved outside this
    function. Tensors are handed to ``Scatter.apply``; for a fixed set of
    (target count, size along ``dim``) combinations explicit uneven chunk
    sizes are passed, otherwise ``None`` is passed as the chunk sizes.
    Non-empty tuples, lists and dicts are scattered element by element, and
    the per-element results are transposed so that each output entry groups
    the matching piece of every element. Any other object is repeated once
    per entry of ``target_gpus``.
    """
    if isinstance(obj, torch.Tensor):
        # Hard-coded splits: {target count: {size along dim: chunk sizes}}.
        # In every entry the first chunk is the smallest one.
        uneven_splits = {
            4: {
                22: (4, 6, 6, 6),
                60: (12, 16, 16, 16),
                144: (24, 40, 40, 40),
            },
            8: {
                46: (4, 6, 6, 6, 6, 6, 6, 6),
                62: (6, 8, 8, 8, 8, 8, 8, 8),
                94: (10, 12, 12, 12, 12, 12, 12, 12),
                110: (12, 14, 14, 14, 14, 14, 14, 14),
                118: (13, 15, 15, 15, 15, 15, 15, 15),
                126: (14, 16, 16, 16, 16, 16, 16, 16),
                134: (15, 17, 17, 17, 17, 17, 17, 17),
                142: (16, 18, 18, 18, 18, 18, 18, 18),
            },
            16: {
                222: (12, 14, 14, 14, 14, 14, 14, 14,
                      14, 14, 14, 14, 14, 14, 14, 14),
            },
        }
        splits_by_size = uneven_splits.get(len(target_gpus))
        # The tensor size is only queried when the target count has entries.
        if splits_by_size is None:
            chunk_sizes = None
        else:
            chunk_sizes = splits_by_size.get(obj.size(dim))
        return Scatter.apply(target_gpus, chunk_sizes, dim, obj)

    if isinstance(obj, tuple) and len(obj) > 0:
        scattered = [scatter_map(element) for element in obj]
        return list(zip(*scattered))

    if isinstance(obj, list) and len(obj) > 0:
        scattered = [scatter_map(element) for element in obj]
        return [list(group) for group in zip(*scattered)]

    if isinstance(obj, dict) and len(obj) > 0:
        # Each (key, value) pair is scattered as a tuple, then the mapping
        # type is rebuilt from the regrouped pairs.
        rebuild = type(obj)
        scattered = [scatter_map(pair) for pair in obj.items()]
        return [rebuild(group) for group in zip(*scattered)]

    # Nothing to split: hand the same object to every target.
    return [obj for _ in target_gpus]
def scatter_distrib(target_gpus, mystery, dim, obj):
    """Scatter a normal distribution by scattering its mean and stddev.

    Args:
        target_gpus: Forwarded to ``Scatter.apply`` as the scatter targets.
        mystery: Unused by this function.
        dim: Forwarded to ``Scatter.apply`` as the dimension argument.
        obj: Object to scatter.

    Returns:
        A tuple of new ``type(obj)`` instances, one per scattered mean piece,
        when ``obj`` is a ``dist.Normal`` or ``dist.MultivariateNormal``
        (or subclass); ``None`` for any other object.
    """
    supported = (dist.Normal, dist.MultivariateNormal)
    if not issubclass(type(obj), supported):
        return None

    means = Scatter.apply(target_gpus, None, dim, obj.mean)
    stddev = Scatter.apply(target_gpus, None, dim, obj.stddev)
    # NOTE(review): the scattered stddev is passed as the second positional
    # constructor argument for both distribution types -- confirm this is the
    # parameter MultivariateNormal expects in that position.
    return tuple(
        type(obj)(means[index], stddev[index]) for index in range(len(means))
    )
def scatter_map(obj):
    """Recursively split ``obj`` into per-target pieces.

    ``target_gpus`` and ``dim`` are free variables resolved outside this
    function. Tensors are handed to ``Scatter.apply``; for a fixed set of
    (target count, size along ``dim``) combinations explicit uneven chunk
    sizes are passed, otherwise ``None`` is passed as the chunk sizes.
    Non-empty tuples and dicts are scattered element by element, and the
    per-element results are transposed so that each output entry groups the
    matching piece of every element. A list of exactly four items is
    returned as a new list of those items; every other object (including
    lists of any other length) is repeated once per entry of
    ``target_gpus``.
    """
    if isinstance(obj, torch.Tensor):
        # Hard-coded splits: {target count: {size along dim: chunk sizes}}.
        # In every entry the first chunk is the smallest one.
        uneven_splits = {
            4: {
                8: (1, 2, 2, 3),
                24: (2, 7, 7, 8),
                12: (1, 3, 4, 4),
                36: (6, 10, 10, 10),
                16: (2, 4, 5, 5),
                48: (9, 13, 13, 13),
                22: (4, 6, 6, 6),
                60: (12, 16, 16, 16),
                144: (24, 40, 40, 40),
            },
            8: {
                46: (4, 6, 6, 6, 6, 6, 6, 6),
                62: (6, 8, 8, 8, 8, 8, 8, 8),
                94: (10, 12, 12, 12, 12, 12, 12, 12),
                110: (12, 14, 14, 14, 14, 14, 14, 14),
                118: (13, 15, 15, 15, 15, 15, 15, 15),
                126: (14, 16, 16, 16, 16, 16, 16, 16),
                134: (15, 17, 17, 17, 17, 17, 17, 17),
                142: (16, 18, 18, 18, 18, 18, 18, 18),
            },
            16: {
                222: (12, 14, 14, 14, 14, 14, 14, 14,
                      14, 14, 14, 14, 14, 14, 14, 14),
            },
        }
        splits_by_size = uneven_splits.get(len(target_gpus))
        # The tensor size is only queried when the target count has entries.
        if splits_by_size is None:
            chunk_sizes = None
        else:
            chunk_sizes = splits_by_size.get(obj.size(dim))
        return Scatter.apply(target_gpus, chunk_sizes, dim, obj)

    if isinstance(obj, tuple) and len(obj) > 0:
        scattered = [scatter_map(element) for element in obj]
        return list(zip(*scattered))

    if isinstance(obj, list) and len(obj) > 0:
        # Lists are not scattered recursively here; the recursive variant
        # was: list(map(list, zip(*map(scatter_map, obj))))
        # NOTE(review): the four-item special case presumably pairs one item
        # with each of four targets -- confirm against the caller.
        if len(obj) == 4:
            return [obj[index] for index in range(4)]

    if isinstance(obj, dict) and len(obj) > 0:
        # Each (key, value) pair is scattered as a tuple, then the mapping
        # type is rebuilt from the regrouped pairs.
        rebuild = type(obj)
        scattered = [scatter_map(pair) for pair in obj.items()]
        return [rebuild(group) for group in zip(*scattered)]

    # Nothing to split: hand the same object to every target.
    return [obj for _ in target_gpus]