Example #1
  def slicewise(self, fn, *inputs):
    """Execute a function in parallel on all slices.

    Args:
      fn: a function from tf.Tensors to tf.Tensor or a tuple of tf.Tensors.
      *inputs: a list of inputs.  Each input is either a LaidOutTensor or
        is convertible to a tf.Tensor.
    Returns:
      a LaidOutTensor, or a tuple of LaidOutTensors if fn returns a tuple.
    """
    if fn == tf.add:
      assert len(inputs) == 2
      if isinstance(inputs[0], mtf.LazyAllreduceSum):
        # sum of LazyAllreduceSum (keep delaying the allreduce)
        return inputs[0] + inputs[1]
    # convert all inputs to LaidOutTensor where possible
    inputs = mtf.convert_args_to_laid_out_tensors(inputs)
    ret = fn(*[
        x.one_slice if isinstance(x, self.LaidOutTensor) else x
        for x in inputs])
    if isinstance(ret, tuple):
      return tuple([self.LaidOutTensor([t]) for t in ret])
    else:
      return self.LaidOutTensor([ret])
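A minimal usage sketch of this variant, assuming `mesh_impl` is an instance of the class above and `lazy_sum`, `other` are laid-out operands (all three names are hypothetical):

# Hypothetical usage sketch: adding to a LazyAllreduceSum takes the fast
# path above, so the pending allreduce stays delayed rather than being
# forced before the add.
delayed = mesh_impl.slicewise(tf.add, lazy_sum, other)

# Any other fn falls through to the generic single-slice path.
activated = mesh_impl.slicewise(tf.nn.relu, delayed)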
Example #2
  def to_tf_tensor(self, *inputs):
    """Convert laid-out inputs to tf.Tensor slices.

    If the last dimensions of the two inputs differ, the slices of the
    first input are concatenated along the last axis on a single device,
    then leading axes of size 3 are averaged away; otherwise, or on any
    failure, the original second input is returned unchanged.
    """
    inputs_save = inputs[1]
    inputs = mtf.convert_args_to_laid_out_tensors(inputs)
    inputs = [x.tensor_list if isinstance(x, self.LaidOutTensor)
              else [x] * len(self.devices) for x in inputs]

    # Special-cased tensor: hand back the slice lists unmodified.
    if inputs[1][0].name == "logits/add_1/parallel_0_1/Add:0":
      return inputs

    try:
      if inputs[0][0].shape[-1] != inputs[1][0].shape[-1]:
        # Concatenate the slices of the first input along its last axis,
        # on a single device.
        devices = ["GPU:0"]
        tensor_axis = len(inputs[0][0].shape) - 1
        slices = mtf.parallel(
            devices, tf.concat, [inputs[0]],
            axis=[tensor_axis] * len(devices))
      else:
        return inputs_save

      # For slices of rank < 5, average away leading axes of size 3.
      for i in range(len(slices)):
        while slices[i].shape[0] == 3 and len(slices[i].shape) < 5:
          slices[i] = tf.reduce_mean(slices[i], axis=[0])
      return slices
    except Exception:
      # Fall back to the original second input on any error.
      return inputs_save
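A hedged sketch of how this helper might be called; `mesh_impl`, `laid_out_a`, and `laid_out_b` are assumed names, not part of the example:

# Hypothetical usage sketch: if the last dimensions of the two inputs
# differ, the slices of laid_out_a come back concatenated (and possibly
# mean-reduced); otherwise laid_out_b is returned as passed in.
result = mesh_impl.to_tf_tensor(laid_out_a, laid_out_b)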
Example #3
  def to_tf_tensor2(self, *inputs):
    """Concatenate the slices of the first input along its last axis.

    The concatenation runs on a single device; the second input only
    participates in the layout conversion.
    """
    inputs = mtf.convert_args_to_laid_out_tensors(inputs)
    inputs = [x.tensor_list if isinstance(x, self.LaidOutTensor)
              else [x] * len(self.devices) for x in inputs]
    devices = ["GPU:0"]
    tensor_axis = len(inputs[0][0].shape) - 1
    ret = mtf.parallel(
        devices, tf.concat, [inputs[0]],
        axis=[tensor_axis] * len(devices))
    return ret
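A matching sketch for this simpler variant (same assumed names as above):

# Hypothetical usage sketch: unconditionally concatenate the slices of the
# first input along its last axis on GPU:0.
concatenated = mesh_impl.to_tf_tensor2(laid_out_a, laid_out_b)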
Example #4
  def slicewise(self, fn, *inputs):
    """Execute a function in parallel on all slices.

    Args:
      fn: a function from tf.Tensors to tf.Tensor or a tuple of tf.Tensors.
      *inputs: a list of inputs.  Each input is either a LaidOutTensor or
        is convertible to a tf.Tensor.
    Returns:
      a LaidOutTensor, or a tuple of LaidOutTensors if fn returns a tuple.
    """
    # convert all inputs to LaidOutTensor where possible
    inputs = mtf.convert_args_to_laid_out_tensors(inputs)
    ret = fn(*[
        x.one_slice if isinstance(x, self.LaidOutTensor) else x
        for x in inputs])
    if isinstance(ret, tuple):
      return tuple([self.LaidOutTensor([t]) for t in ret])
    else:
      return self.LaidOutTensor([ret])
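A short usage sketch for this single-slice variant, again with a hypothetical `mesh_impl` instance and laid-out operands:

# Hypothetical usage sketch: fn runs once, on the single slice of each input.
summed = mesh_impl.slicewise(tf.add, laid_out_x, laid_out_y)

# A fn that returns a tuple yields a tuple of LaidOutTensors
# (tf.nn.top_k returns a (values, indices) pair).
values, indices = mesh_impl.slicewise(
    lambda x: tf.nn.top_k(x, k=2), laid_out_x)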
Example #5
  def slicewise(self, fn, *inputs):
    """Execute a function in parallel on all slices.

    Args:
      fn: a function from tf.Tensors to tf.Tensor or a tuple of tf.Tensors.
      *inputs: a list of inputs.  Each input is either a LaidOutTensor or
        is convertible to a tf.Tensor.
    Returns:
      a LaidOutTensor, or a tuple of LaidOutTensors if fn returns a tuple.
    """
    # convert all inputs to LaidOutTensor where possible
    inputs = mtf.convert_args_to_laid_out_tensors(inputs)
    inputs = [x.tensor_list if isinstance(x, self.LaidOutTensor)
              else [x] * len(self.devices) for x in inputs]
    ret = mtf.parallel(self.devices, fn, *inputs)
    if isinstance(ret[0], tuple):
      ret = mtf.transpose_list_of_lists(ret)
      return tuple([self.LaidOutTensor(t) for t in ret])
    else:
      return self.LaidOutTensor(ret)
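Finally, a sketch for the multi-device variant, where fn runs once per device through mtf.parallel (names hypothetical as before):

# Hypothetical usage sketch: tf.tanh is applied to each device's slice in
# parallel, giving one LaidOutTensor with one slice per device.
activated = mesh_impl.slicewise(tf.tanh, laid_out_x)

# A plain Python scalar is broadcast: the `[x] * len(self.devices)` branch
# replicates it once per device before fn is applied.
scaled = mesh_impl.slicewise(tf.multiply, laid_out_x, 2.0)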