def slicewise(self, fn, *inputs):
  """Execute a function in parallel on all slices.

  Args:
    fn: a function from tf.Tensors to tf.Tensor or a tuple of tf.Tensors.
    *inputs: a list of inputs.  Each input is either a LaidOutTensor or
      is convertible to a tf.Tensor.

  Returns:
    a LaidOutTensor, or a tuple of LaidOutTensors if fn returns a tuple.
  """
  if fn == tf.add:
    assert len(inputs) == 2
    if isinstance(inputs[0], mtf.LazyAllreduceSum):
      # sum of LazyAllreduceSum (keep delaying the allreduce)
      return inputs[0] + inputs[1]
  # convert all inputs to LaidOutTensor where possible
  inputs = mtf.convert_args_to_laid_out_tensors(inputs)
  ret = fn(*[
      x.one_slice if isinstance(x, self.LaidOutTensor) else x
      for x in inputs])
  if isinstance(ret, tuple):
    return tuple([self.LaidOutTensor([t]) for t in ret])
  else:
    return self.LaidOutTensor([ret])
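# Usage sketch (hypothetical names, not part of this class): given a mesh
# implementation `mesh_impl` and LaidOutTensors `x` and `y`, slicewise
# applies `fn` to the single local slice of each input.  When `fn` is
# tf.add and the first operand is a LazyAllreduceSum, the addition is
# folded into the pending allreduce instead of forcing it:
#
#   z = mesh_impl.slicewise(tf.add, x, y)   # allreduce may stay deferred
#   a, b = mesh_impl.slicewise(lambda t: (t, tf.square(t)), x)
#
# A tuple-returning `fn` yields a tuple of LaidOutTensors, one per output.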
def to_tf_tensor(self, *inputs):
  """Merge the slices of the first input into full tf.Tensors.

  Falls back to returning the second input unchanged if the slices
  cannot be concatenated.
  """
  inputs_save = inputs[1]
  inputs = mtf.convert_args_to_laid_out_tensors(inputs)
  inputs = [
      x.tensor_list if isinstance(x, self.LaidOutTensor)
      else [x] * len(self.devices) for x in inputs]
  # Special case: leave this hard-coded logits tensor untouched.
  if inputs[1][0].name == "logits/add_1/parallel_0_1/Add:0":
    return inputs
  try:
    if inputs[0][0].shape[-1] != inputs[1][0].shape[-1]:
      # Concatenate the slices of the first input along their last axis,
      # all on a single device.
      devices = ["GPU:0"]
      tensor_axis = len(inputs[0][0].shape) - 1
      slices = mtf.parallel(
          devices, tf.concat, [inputs[0]],
          axis=[tensor_axis] * len(devices))
    else:
      return inputs_save
    for i in range(len(slices)):
      # Collapse leading axes of size 3 by averaging, as long as the
      # slice rank is below 5.
      while slices[i].shape[0] == 3 and len(slices[i].shape) < 5:
        slices[i] = tf.reduce_mean(slices[i], axis=[0])
    return slices
  except Exception:
    # If anything above fails (e.g. unknown static shapes), return the
    # original second input unchanged.
    return inputs_save
def to_tf_tensor2(self, *inputs):
  """Concatenate the slices of the first input along their last axis."""
  inputs = mtf.convert_args_to_laid_out_tensors(inputs)
  inputs = [
      x.tensor_list if isinstance(x, self.LaidOutTensor)
      else [x] * len(self.devices) for x in inputs]
  devices = ["GPU:0"]
  tensor_axis = len(inputs[0][0].shape) - 1
  return mtf.parallel(
      devices, tf.concat, [inputs[0]],
      axis=[tensor_axis] * len(devices))
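# Usage sketch (hypothetical, assuming a mesh_impl whose self.devices is
# populated): both helpers take laid-out inputs and merge the slices of the
# first one on a single device, e.g.
#
#   merged = mesh_impl.to_tf_tensor2(laid_out_x)
#
# to_tf_tensor additionally expects a second input whose last-axis size is
# used as the "already merged" reference, and falls back to returning that
# second input when the sizes match or concatenation fails.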
def slicewise(self, fn, *inputs):
  """Execute a function in parallel on all slices.

  Args:
    fn: a function from tf.Tensors to tf.Tensor or a tuple of tf.Tensors.
    *inputs: a list of inputs.  Each input is either a LaidOutTensor or
      is convertible to a tf.Tensor.

  Returns:
    a LaidOutTensor, or a tuple of LaidOutTensors if fn returns a tuple.
  """
  # convert all inputs to LaidOutTensor where possible
  inputs = mtf.convert_args_to_laid_out_tensors(inputs)
  ret = fn(*[
      x.one_slice if isinstance(x, self.LaidOutTensor) else x
      for x in inputs])
  if isinstance(ret, tuple):
    return tuple([self.LaidOutTensor([t]) for t in ret])
  else:
    return self.LaidOutTensor([ret])
def slicewise(self, fn, *inputs):
  """Execute a function in parallel on all slices.

  Args:
    fn: a function from tf.Tensors to tf.Tensor or a tuple of tf.Tensors.
    *inputs: a list of inputs.  Each input is either a LaidOutTensor or
      is convertible to a tf.Tensor.

  Returns:
    a LaidOutTensor, or a tuple of LaidOutTensors if fn returns a tuple.
  """
  # convert all inputs to LaidOutTensor where possible
  inputs = mtf.convert_args_to_laid_out_tensors(inputs)
  inputs = [
      x.tensor_list if isinstance(x, self.LaidOutTensor)
      else [x] * len(self.devices) for x in inputs]
  ret = mtf.parallel(self.devices, fn, *inputs)
  if isinstance(ret[0], tuple):
    ret = mtf.transpose_list_of_lists(ret)
    return tuple([self.LaidOutTensor(t) for t in ret])
  else:
    return self.LaidOutTensor(ret)
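# Usage sketch (hypothetical): unlike the one-slice variants above, this
# version runs `fn` once per device via mtf.parallel, so each resulting
# LaidOutTensor carries len(self.devices) slices:
#
#   y = mesh_impl.slicewise(tf.exp, x)
#   # len(y.tensor_list) == len(mesh_impl.devices)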