Exemple #1
0
 def FPropMeta(cls, p, *args):
   """Chains the sub-layers' `FPropMeta` calls, `p.repeat` times over.

   Each sub-layer in `p.sub` receives the output shapes of the previous call,
   and its reported flops are added to a running sum.

   Args:
     cls: Not used by this function.
     p: Params providing `repeat` and the iterable `sub`; every entry of
       `sub` must expose `name` and `cls`.
     *args: Input shapes, validated with `py_utils.CheckShapes`.

   Returns:
     A `py_utils.NestedMap` with `flops` (sum over every sub-layer call) and
     `out_shapes` (the shapes after the last call, or `args` itself when no
     sub-layer call happened).
   """
   py_utils.CheckShapes(args)
   shapes = args
   flops = 0
   for _ in range(p.repeat):
     for layer_p in p.sub:
       tf.logging.vlog(1, '  seq abs fprop %s %s %d %s', layer_p.name,
                       layer_p.cls, len(shapes), str(shapes))
       layer_meta = layer_p.cls.FPropMeta(layer_p, *shapes)
       py_utils.CheckShapes(layer_meta.out_shapes)
       flops += layer_meta.flops
       # The next sub-layer consumes what this one produced.
       shapes = layer_meta.out_shapes
   return py_utils.NestedMap(flops=flops, out_shapes=shapes)
Exemple #2
0
 def FPropMeta(cls, p, *args):
   """Computes metadata for `p.body` applied `p.repeat` times.

   Each non-None arg has the leading dimension of its static shape removed
   before being handed to the body's `FPropMeta`; `p.repeat` is then prepended
   to each non-None output shape of the body.

   Args:
     cls: Not used by this function.
     p: Params providing `body` (with a `cls` attribute) and `repeat`.
     *args: Inputs; each is either None or exposes `get_shape().as_list()`.

   Returns:
     A `py_utils.NestedMap` with `flops` (body flops times `p.repeat`) and
     `out_shapes` (a tuple of `tshape.Shape` / None entries).
   """
   py_utils.CheckShapes(args)
   body_in_shapes = []
   for arg in args:
     if arg is None:
       body_in_shapes.append(None)
       continue
     # Everything after the first dimension of the arg's static shape.
     body_in_shapes.append(tshape.Shape(arg.get_shape().as_list()[1:]))
   body_meta = p.body.cls.FPropMeta(p.body, *body_in_shapes)
   py_utils.CheckShapes(body_meta.out_shapes)
   flops = body_meta.flops * p.repeat
   stacked = []
   for shape in body_meta.out_shapes:
     if shape is None:
       stacked.append(None)
     else:
       stacked.append(tshape.Shape([p.repeat] + shape[:]))
   return py_utils.NestedMap(flops=flops, out_shapes=tuple(stacked))
    def FPropMeta(cls, p, *args):
        """Runs every sub-layer on the same inputs and reduces the results.

        Args:
          cls: Not used by this function.
          p: Params providing the iterable `sub` (entries expose `cls`) and
            the callable `reduce_meta`.
          *args: Input shapes, validated with `py_utils.CheckShapes`.

        Returns:
          The object returned by `p.reduce_meta`, called with the list of
          per-sub-layer output shapes, after the sub-layers' summed flops
          have been added to its `flops` attribute.
        """
        py_utils.CheckShapes(args)
        sub_flops = 0
        per_sub_shapes = []
        for layer_p in p.sub:
            layer_meta = layer_p.cls.FPropMeta(layer_p, *args)
            py_utils.CheckShapes(layer_meta.out_shapes)
            sub_flops += layer_meta.flops
            per_sub_shapes.append(layer_meta.out_shapes)

        reduced = p.reduce_meta(per_sub_shapes)
        py_utils.CheckShapes(reduced.out_shapes)
        # Fold the sub-layer cost into whatever the reduction itself reported.
        reduced.flops += sub_flops
        return reduced
Exemple #4
0
 def FPropMeta(cls, p, *args):
   """Selects the leading input shapes as outputs at zero flops.

   Args:
     cls: Not used by this function.
     p: Params providing the integer `n`.
     *args: Input shapes, validated with `py_utils.CheckShapes`.

   Returns:
     A `py_utils.NestedMap` with `flops=0` and `out_shapes` equal to the
     first `p.n` args when `p.n > 1`, otherwise a 1-tuple holding `args[0]`.
   """
   py_utils.CheckShapes(args)
   selected = args[:p.n] if p.n > 1 else (args[0],)
   return py_utils.NestedMap(flops=0, out_shapes=selected)
 def FPropMeta(cls, p, inputs, *args):
     """Reports metadata for an op that removes dimensions 1 and 2 of `inputs`.

     Args:
       cls: Not used by this function.
       p: Not used by this function.
       inputs: Shape exposing `as_list()` and `num_elements()`.
       *args: Ignored.

     Returns:
       A `py_utils.NestedMap` with `flops` equal to `inputs.num_elements()`
       and a single output `tf.TensorShape` made of dimension 0 followed by
       dimensions 3 onward of `inputs`.
     """
     py_utils.CheckShapes((inputs, ))
     dims = inputs.as_list()
     reduced_shape = tf.TensorShape(dims[:1] + dims[3:])
     flops = inputs.num_elements()
     return py_utils.NestedMap(flops=flops, out_shapes=(reduced_shape, ))
Exemple #6
0
  def FPropMeta(cls, p, *args):
    """Propagates shapes through the sub-layers wired by their signatures.

    Input shapes are stored under the names in `p.input_endpoints`. Each
    `(signature, sub)` pair in `p.sub` is then evaluated in order: its inputs
    are looked up by name, its `FPropMeta` is called, and its output shapes
    are stored under the signature's output names.

    Args:
      cls: Not used by this function.
      p: Params providing `input_endpoints`, `output_endpoints` and `sub`
        (an iterable of `(signature, sub_params)` pairs).
      *args: Input shapes; there must be one per entry of
        `p.input_endpoints`.

    Returns:
      A `py_utils.NestedMap` with `flops` (sum over all sub-layers) and
      `out_shapes` (tuple of the shapes looked up for `p.output_endpoints`).
    """
    py_utils.CheckShapes(args)
    assert len(p.input_endpoints) == len(args)

    store = GraphTensors()
    for endpoint, shape in zip(p.input_endpoints, args):
      store.StoreTensor(endpoint, shape)

    flops = 0
    for signature, layer_p in p.sub:
      sig = GraphSignature(signature)
      # Resolve every name in the signature's inputs to its stored shape.
      layer_inputs = py_utils.NestedMap(inputs=sig.inputs).Transform(
          store.GetTensor).inputs

      layer_meta = layer_p.cls.FPropMeta(layer_p, *layer_inputs)
      flops += layer_meta.flops
      produced = layer_meta.out_shapes
      assert len(produced) == len(sig.outputs)
      for endpoint, shape in zip(sig.outputs, produced):
        store.StoreTensor(endpoint, shape)

    out_shapes = tuple(map(store.GetTensor, p.output_endpoints))
    return py_utils.NestedMap(flops=flops, out_shapes=out_shapes)
Exemple #7
0
  def FPropMeta(cls, p, *args):
    """Propagates shapes through sub-layers connected by named tensors.

    Input shapes are registered under `p.input_endpoints`. Each
    `(signature, sub)` pair in `p.sub` is processed in order: the signature
    is split by `GraphLayer.ParseSignature` into input and output names, the
    inputs are looked up, the sub-layer's `FPropMeta` is called, and its
    output shapes are registered under the output names.

    Args:
      cls: Not used by this function.
      p: Params providing `input_endpoints`, `output_endpoints` and `sub`
        (an iterable of `(signature, sub_params)` pairs).
      *args: Input shapes; there must be one per entry of
        `p.input_endpoints`.

    Returns:
      A `py_utils.NestedMap` with `flops` (sum over all sub-layers) and
      `out_shapes` (tuple of the shapes looked up for `p.output_endpoints`).
    """
    py_utils.CheckShapes(args)
    assert len(p.input_endpoints) == len(args)

    registry = py_utils.NestedMap()
    for endpoint, shape in zip(p.input_endpoints, args):
      GraphLayer.AddNamedTensor(p, endpoint, shape, registry)

    flops = 0
    for signature, layer_p in p.sub:
      in_names, out_names = GraphLayer.ParseSignature(signature)
      layer_inputs = []
      for tensor_name in in_names:
        layer_inputs.append(
            GraphLayer.GetNamedTensor(p, registry, tensor_name))

      layer_meta = layer_p.cls.FPropMeta(layer_p, *layer_inputs)
      flops += layer_meta.flops
      produced = layer_meta.out_shapes
      assert len(produced) == len(out_names)
      for endpoint, shape in zip(out_names, produced):
        GraphLayer.AddNamedTensor(p, endpoint, shape, registry)

    out_shapes = tuple(
        GraphLayer.GetNamedTensor(p, registry, endpoint)
        for endpoint in p.output_endpoints)
    return py_utils.NestedMap(flops=flops, out_shapes=out_shapes)
Exemple #8
0
 def FPropMeta(cls, p, inputs):
   """Reports metadata for mapping the last dim to `p.output_dims`.

   Args:
     cls: Not used by this function.
     p: Params providing `input_dims` and `output_dims`.
     inputs: Shape supporting indexing, slicing, `+` with a list, and a
       `size` attribute; its last dimension must equal `p.input_dims`.

   Returns:
     A `py_utils.NestedMap` with `flops` equal to
     `inputs.size * p.output_dims * 2` and one output shape: `inputs` with
     its last dimension replaced by `p.output_dims`.
   """
   py_utils.CheckShapes((inputs,))
   assert p.input_dims == inputs[-1]
   # Each c_{ij} += x_{ik} * y_{kj} update is counted as 2 flops.
   flops = inputs.size * p.output_dims * 2
   projected = inputs[:-1] + [p.output_dims]
   return py_utils.NestedMap(flops=flops, out_shapes=(projected,))
Exemple #9
0
 def FPropMeta(cls, p, inputs, paddings):
     """Reports flops and output shapes for a layer driven by `filter_shape`.

     Args:
       cls: Not used by this function.
       p: Params providing `filter_shape` (indexed at 0, 2 and 3) and
         `weight_tiling_factor`.
       inputs: Shape that unpacks into exactly four dims; the third must
         equal 1.
       paddings: Shape passed through unchanged as the second output.

     Returns:
       A `py_utils.NestedMap` with `flops` and
       `out_shapes=(output_shape, paddings)`, where `output_shape` is
       `inputs` with its last dim replaced by
       `filter_shape[2] * filter_shape[3] * weight_tiling_factor`.
     """
     py_utils.CheckShapes((inputs, paddings))
     b, t, f, ic = inputs
     assert f == 1
     filt = p.filter_shape
     oc = filt[2] * filt[3] * p.weight_tiling_factor
     out_shape = tshape.Shape([b, t, f, oc])
     # The factor of 5 is the per-multiply-add cost used by the original code.
     flops = b * t * f * filt[0] * ic * oc * 5
     return py_utils.NestedMap(flops=flops, out_shapes=(out_shape, paddings))
Exemple #10
0
  def FPropMeta(cls, p, *args):
    """Runs every sub-layer on the same inputs and merges the results.

    Args:
      cls: Not used by this function.
      p: Params providing the iterable `sub` (entries expose `name` and
        `cls`) and the callable `merge_meta`.
      *args: Input shapes, validated with `py_utils.CheckShapes`.

    Returns:
      The object returned by `p.merge_meta`, called with the list of
      per-sub-layer output shapes, after the sub-layers' summed flops have
      been added to its `flops` attribute.
    """
    py_utils.CheckShapes(args)
    sub_flops = 0
    per_sub_shapes = []
    for layer_p in p.sub:
      tf.logging.vlog(1, '  par abs fprop %s %s %d %s', layer_p.name,
                      layer_p.cls, len(args), str(args))
      layer_meta = layer_p.cls.FPropMeta(layer_p, *args)
      py_utils.CheckShapes(layer_meta.out_shapes)
      vlog_prefix = '  par abs fprop {} {} {} {}'.format(
          layer_p.name, layer_p.cls, len(args), str(args))
      layer_meta.VLog(1, vlog_prefix)
      sub_flops += layer_meta.flops
      per_sub_shapes.append(layer_meta.out_shapes)

    merged = p.merge_meta(per_sub_shapes)
    py_utils.CheckShapes(merged.out_shapes)
    merged.flops += sub_flops
    return merged
Exemple #11
0
 def FPropMeta(cls, p, *args):
   """Applies `p.fn_meta` to each non-None input shape independently.

   Args:
     cls: Not used by this function.
     p: Params providing `fn_meta`, a callable returning a `(cost, shape)`
       pair for one input.
     *args: Input shapes; None entries are passed through as None.

   Returns:
     A `py_utils.NestedMap` with `flops` (sum of the returned costs) and
     `out_shapes` (tuple with one entry per arg, in order).
   """
   total_cost = 0
   out_shapes = []
   for arg in args:
     if arg is None:
       out_shapes.append(None)
       continue
     cost, out_shape = p.fn_meta(arg)
     py_utils.CheckShapes((out_shape,))
     total_cost += cost
     out_shapes.append(out_shape)
   return py_utils.NestedMap(flops=total_cost, out_shapes=tuple(out_shapes))
Exemple #12
0
 def FPropMeta(cls, p, inputs, *args):
   """Reports flops and output shapes for an embedding step.

   Args:
     cls: Not used by this function.
     p: Params providing `token_emb.vocab_size`, `token_emb.embedding_dim`
       and the flag `add_tgt_embedding_layer`.
     inputs: Shape that unpacks into exactly two dims.
     *args: Further shapes; when `p.add_tgt_embedding_layer` is set,
       `args[1]` must unpack into two dims.

   Returns:
     A `py_utils.NestedMap` with `flops` and `out_shapes`: `inputs` with the
     embedding dim appended, followed by at most the first seven of `args`
     (with `args[1]` likewise extended when `p.add_tgt_embedding_layer`).
   """
   # TODO(ankurbpn): return accurate estimate of flops.
   py_utils.CheckShapes((inputs,))
   flops_per_element = 2  # Is this correct?
   vocab = p.token_emb.vocab_size
   dim = p.token_emb.embedding_dim
   src_d0, src_d1 = inputs
   flops = flops_per_element * src_d0 * src_d1 * dim * vocab
   embedded_inputs = tshape.Shape([src_d0, src_d1, dim])
   rest = list(args)
   if p.add_tgt_embedding_layer:
     tgt_d0, tgt_d1 = args[1]
     rest[1] = tshape.Shape([tgt_d0, tgt_d1, dim])
   out_shapes = (embedded_inputs,) + tuple(rest[:7])
   return py_utils.NestedMap(flops=flops, out_shapes=out_shapes)
Exemple #13
0
def _common_gpipe_transformer_fprop_meta(p, inputs, *args):
    """GPipe FPropMeta function.

    Args:
      p: Params; reads `has_aux_atten`, `is_transparent`,
        `transparent_merger_tpl` and `final_enc_layer`.
      inputs: Shape that unpacks into exactly three dims.
      *args: Further shapes, forwarded after `inputs`. In the transparent
        encoder case (`is_transparent` set and `has_aux_atten` unset) the
        entries from index 5 onward are rewritten as described inline.

    Returns:
      A `py_utils.NestedMap` with `flops` and
      `out_shapes=(inputs,) + args` (after any rewriting of `args`).
    """
    # TODO(huangyp): return accurate estimate of flops.
    py_utils.CheckShapes((inputs, ))
    flops_per_element = 5
    src_time, source_batch, dim = inputs
    # Cost is quadratic in the first dimension of `inputs`.
    flops = flops_per_element * src_time * src_time * source_batch * dim
    # `*args` is always bound to a tuple, so this line leaves `args` as is.
    args = args if isinstance(args, tuple) else (args, )
    if not p.has_aux_atten and p.is_transparent:  # Transparent Encoder FPropMeta
        if p.transparent_merger_tpl is not None:
            # Keep the first five entries, then place `inputs` at index 5 and
            # a 1-D shape of `num_sources` at index 6.
            args = args[:5] + (
                inputs, tshape.Shape([p.transparent_merger_tpl.num_sources]))
        # Replace index 6 with a 1-D shape whose only dim is one less than the
        # current entry's leading dim; anything after index 6 is dropped.
        # NOTE(review): requires args[6] to exist at this point - confirm
        # callers always supply it when transparent_merger_tpl is None.
        args = args[:6] + (tshape.Shape([args[6][0] - 1]), )
        if p.final_enc_layer:
            # Indices 5 and 6 are emitted as None for the final encoder layer.
            args = args[:5] + (None, None)
    return py_utils.NestedMap(flops=flops, out_shapes=(inputs, ) + args)
Exemple #14
0
 def FPropMeta(cls, p, inputs, *args):
     """Reports flops and output shapes, adjusting `args` when transparent.

     Args:
       cls: Not used by this function.
       p: Params; reads `is_transparent`, `has_aux_atten` and
         `num_transparent_outputs`.
       inputs: Shape whose `as_list()` yields exactly three dims.
       *args: Further shapes forwarded after `inputs`.

     Returns:
       A `py_utils.NestedMap` with `flops` and `out_shapes=(inputs,) + rest`,
       where `rest` is `args`, possibly shortened by one entry or extended
       with copies of `inputs` as described inline.
     """
     # TODO(huangyp): return accurate estimate of flops.
     py_utils.CheckShapes((inputs, ))
     flops_per_element = 5
     src_time, source_batch, dim = inputs.as_list()
     flops = flops_per_element * src_time * src_time * source_batch * dim

     rest = args
     if p.is_transparent:
         if p.has_aux_atten:  # Decoder FPropMeta
             # Drop the trailing entry only when more than five are present.
             if len(rest) > 5:
                 rest = rest[:-1]
         elif p.num_transparent_outputs == 0:
             rest = rest + (inputs, )
         else:
             # A non-positive count appends nothing.
             n_copies = p.num_transparent_outputs - len(rest) + 4
             rest = rest + (inputs, ) * n_copies
     return py_utils.NestedMap(flops=flops, out_shapes=(inputs, ) + rest)
Exemple #15
0
def _common_gpipe_transformer_fprop_meta(p, inputs, *args):
  """Reports flops and output shapes for a GPipe transformer layer.

  Args:
    p: Params; reads `is_transparent`, `has_aux_atten` and
      `num_transparent_outputs`.
    inputs: Shape that unpacks into exactly three dims.
    *args: Further shapes forwarded after `inputs`.

  Returns:
    A `py_utils.NestedMap` with `flops` and `out_shapes=(inputs,) + rest`,
    where `rest` is `args`, possibly shortened or extended with copies of
    `inputs` as described inline.
  """
  # TODO(huangyp): return accurate estimate of flops.
  py_utils.CheckShapes((inputs,))
  flops_per_element = 5
  src_time, source_batch, dim = inputs
  flops = flops_per_element * src_time * src_time * source_batch * dim

  rest = args
  if p.is_transparent:
    if p.has_aux_atten:  # Decoder FPropMeta
      # Drop the trailing entry only when more than seven are present.
      if len(rest) > 7:
        rest = rest[:-1]
    elif p.num_transparent_outputs == 0:
      rest = rest + (inputs,)
    elif p.num_transparent_outputs == 1:
      # Switch back to non-transparent mode for decoder.
      rest = rest[:7]
    else:
      # A non-positive count appends nothing.
      n_copies = p.num_transparent_outputs - len(rest) + 6
      rest = rest + (inputs,) * n_copies
  return py_utils.NestedMap(flops=flops, out_shapes=(inputs,) + rest)
Exemple #16
0
 def FPropMeta(cls, p, inputs, *args):
     """Reports flops and output shapes for an embedding step.

     Args:
       cls: Not used by this function.
       p: Params providing `token_emb.vocab_size`,
         `token_emb.embedding_dim` and the flags `add_tgt_embedding_layer`
         and `ret_task_ids`.
       inputs: Shape that unpacks into exactly two dims.
       *args: Further shapes; when `p.add_tgt_embedding_layer` is set,
         `args[1]` must unpack into two dims.

     Returns:
       A `py_utils.NestedMap` with `flops` and `out_shapes`: `inputs` with
       the embedding dim appended, then the first five of `args` (with
       `args[1]` likewise extended when requested), then two None entries,
       then `args[7:]` only when `p.ret_task_ids` is set.
     """
     # TODO(ankurbpn): return accurate estimate of flops.
     py_utils.CheckShapes((inputs, ))
     flops_per_element = 2  # Is this correct?
     vocab = p.token_emb.vocab_size
     dim = p.token_emb.embedding_dim
     src_d0, src_d1 = inputs
     flops = flops_per_element * src_d0 * src_d1 * dim * vocab
     embedded_inputs = tshape.Shape([src_d0, src_d1, dim])
     rest = list(args)
     if p.add_tgt_embedding_layer:
         tgt_d0, tgt_d1 = args[1]
         rest[1] = tshape.Shape([tgt_d0, tgt_d1, dim])
     # Entries at index 7 and beyond survive only when task ids are returned.
     tail = rest[7:] if p.ret_task_ids else []
     rest = rest[:5] + [None, None] + tail
     out_shapes = (embedded_inputs, ) + tuple(rest)
     return py_utils.NestedMap(flops=flops, out_shapes=out_shapes)
Exemple #17
0
 def FPropMeta(cls, p, inputs):
     """Describes the cost and output shape of this layer's `FProp`.

     Args:
       cls: Not used by this function.
       p: Not used by this function.
       inputs: Shape exposing `num_elements()`.

     Returns:
       A `py_utils.NestedMap` with `flops` equal to the element count times
       `_BN_FLOPS_PER_ELEMENT`, and `inputs` as the only output shape.
     """
     py_utils.CheckShapes((inputs, ))
     flops = inputs.num_elements() * _BN_FLOPS_PER_ELEMENT
     return py_utils.NestedMap(flops=flops, out_shapes=(inputs, ))
Exemple #18
0
 def FPropMeta(cls, p, inputs, padding=None):
     """Describes the cost and output shape of this layer's `FProp`.

     Args:
       cls: Not used by this function.
       p: Not used by this function.
       inputs: Shape exposing `num_elements()`.
       padding: Accepted but not used by this function.

     Returns:
       A `py_utils.NestedMap` with `flops` equal to the element count times
       `_BN_FLOPS_PER_ELEMENT`, and `inputs` as the only output shape.
     """
     py_utils.CheckShapes((inputs, ))
     flops = inputs.num_elements() * _BN_FLOPS_PER_ELEMENT
     return py_utils.NestedMap(flops=flops, out_shapes=(inputs, ))
 def FPropMeta(cls, p, inputs):
     """Reports a shape-preserving op costing 5 flops per input element.

     Args:
       cls: Not used by this function.
       p: Not used by this function.
       inputs: Shape exposing `num_elements()`.

     Returns:
       A `py_utils.NestedMap` with `flops = inputs.num_elements() * 5` and
       `inputs` as the only output shape.
     """
     py_utils.CheckShapes((inputs, ))
     flops = inputs.num_elements() * 5
     return py_utils.NestedMap(flops=flops, out_shapes=(inputs, ))
Exemple #20
0
 def FPropMeta(cls, p, *args):
   """Reports metadata for `p.body` applied `p.repeat` times.

   Args:
     cls: Not used by this function.
     p: Params providing `body` (with a `cls` attribute) and `repeat`.
     *args: Input shapes, validated with `py_utils.CheckShapes`.

   Returns:
     A `py_utils.NestedMap` with `flops` equal to the body's flops times
     `p.repeat`. `out_shapes` is the input `args`, not the body's reported
     output shapes (those are only validated).
   """
   py_utils.CheckShapes(args)
   body_p = p.body
   body_meta = body_p.cls.FPropMeta(body_p, *args)
   py_utils.CheckShapes(body_meta.out_shapes)
   flops = body_meta.flops * p.repeat
   return py_utils.NestedMap(flops=flops, out_shapes=args)
Exemple #21
0
 def FPropMeta(cls, p, *args):
   """Reports a zero-cost op whose output shapes equal its input shapes.

   Args:
     cls: Not used by this function.
     p: Not used by this function.
     *args: Input shapes, validated with `py_utils.CheckShapes`.

   Returns:
     A `py_utils.NestedMap` with `flops=0` and `out_shapes=args`.
   """
   py_utils.CheckShapes(args)
   passthrough = py_utils.NestedMap(flops=0, out_shapes=args)
   return passthrough
Exemple #22
0
 def FPropMeta(cls, p, *args):
   """Validates the input shapes and delegates to the body's `FPropMeta`.

   Args:
     cls: Not used by this function.
     p: Params providing `body`, which exposes `cls`.
     *args: Input shapes, validated with `py_utils.CheckShapes`.

   Returns:
     Whatever `p.body.cls.FPropMeta(p.body, *args)` returns.
   """
   py_utils.CheckShapes(args)
   body_p = p.body
   return body_p.cls.FPropMeta(body_p, *args)
Exemple #23
0
 def FPropMeta(cls, p, *args):
   """Obtains the metadata from the user-supplied `p.fn_meta`.

   Args:
     cls: Not used by this function.
     p: Params providing the callable `fn_meta`.
     *args: Input shapes, validated with `py_utils.CheckShapes`.

   Returns:
     The object returned by `p.fn_meta(*args)`, after its `out_shapes` have
     been validated with `py_utils.CheckShapes`.
   """
   py_utils.CheckShapes(args)
   fn_result = p.fn_meta(*args)
   py_utils.CheckShapes(fn_result.out_shapes)
   return fn_result
Exemple #24
0
 def FPropMeta(cls, p, inputs):
     """Reports a nominal single-flop op over a nested structure of shapes.

     Args:
       cls: Not used by this function.
       p: Not used by this function.
       inputs: Structure exposing `Filter(...)` and `Flatten()`; only its
         non-None leaves are validated.

     Returns:
       A `py_utils.NestedMap` with `flops=1` and `inputs` as the only output
       shape entry.
     """

     def _is_present(leaf):
         return leaf is not None

     present_shapes = inputs.Filter(_is_present).Flatten()
     py_utils.CheckShapes(tuple(present_shapes))
     return py_utils.NestedMap(flops=1, out_shapes=(inputs, ))
Exemple #25
0
 def FPropMeta(cls, p, inputs):
     """Reports a shape-preserving op costing about 10 flops per element.

     Args:
       cls: Not used by this function.
       p: Not used by this function.
       inputs: Shape exposing `num_elements()`.

     Returns:
       A `py_utils.NestedMap` with `flops = inputs.num_elements() * 10` and
       `inputs` as the only output shape.
     """
     py_utils.CheckShapes((inputs, ))
     flops_per_element = 10  # Approximately 10 flops per element.
     flops = inputs.num_elements() * flops_per_element
     return py_utils.NestedMap(flops=flops, out_shapes=(inputs, ))
Exemple #26
0
 def FPropMeta(cls, p, inputs):
     """Reports a nominal single-flop, shape-preserving op.

     Args:
       cls: Not used by this function.
       p: Not used by this function.
       inputs: Input shape, validated with `py_utils.CheckShapes`.

     Returns:
       A `py_utils.NestedMap` with `flops=1` and `inputs` as the only output
       shape.
     """
     py_utils.CheckShapes((inputs, ))
     out_shapes = (inputs, )
     return py_utils.NestedMap(flops=1, out_shapes=out_shapes)
Exemple #27
0
 def FPropMeta(cls, p, inputs):
   """Reports a shape-preserving op costing one flop per input element.

   Args:
     cls: Not used by this function.
     p: Params providing `dims`.
     inputs: Shape supporting indexing and a `size` attribute; its last
       dimension must equal `p.dims`.

   Returns:
     A `py_utils.NestedMap` with `flops = inputs.size` and `inputs` as the
     only output shape.
   """
   py_utils.CheckShapes((inputs,))
   assert inputs[-1] == p.dims
   flops = inputs.size
   return py_utils.NestedMap(flops=flops, out_shapes=(inputs,))