예제 #1
0
 def FPropMeta(cls, p, *args):
   """Chains every sub-layer's FPropMeta, p.repeat times, threading shapes.

   Returns a NestedMap whose `flops` is the sum over all sub-layer
   invocations and whose `out_shapes` are those produced by the last
   invocation (or `args` itself when nothing runs).
   """
   py_utils.CheckShapes(args)
   shapes = args
   flops_sum = 0
   for _ in range(p.repeat):
     for child in p.sub:
       tf.logging.vlog(1, '  seq abs fprop %s %s %d %s', child.name,
                       child.cls, len(shapes), str(shapes))
       child_meta = child.cls.FPropMeta(child, *shapes)
       py_utils.CheckShapes(child_meta.out_shapes)
       flops_sum += child_meta.flops
       # Each sub-layer consumes the output shapes of the previous one.
       shapes = child_meta.out_shapes
   return py_utils.NestedMap(flops=flops_sum, out_shapes=shapes)
예제 #2
0
 def FPropMeta(cls, p, *args):
   """Estimates running p.body p.repeat times and stacking its outputs.

   The leading dimension of each input is dropped before querying the body.
   The body's flops are multiplied by p.repeat, and p.repeat is prepended to
   every non-None output shape.
   """
   py_utils.CheckShapes(args)

   def _DropLeadingDim(arg):
     if arg is None:
       return None
     # NOTE(review): relies on `arg` exposing get_shape().as_list(); confirm
     # this holds for the shape objects accepted by CheckShapes.
     return tshape.Shape(arg.get_shape().as_list()[1:])

   def _PrependRepeat(shape):
     if shape is None:
       return None
     return tshape.Shape([p.repeat] + shape[:])

   body_inputs = [_DropLeadingDim(a) for a in args]
   body_meta = p.body.cls.FPropMeta(p.body, *body_inputs)
   py_utils.CheckShapes(body_meta.out_shapes)
   flops = body_meta.flops * p.repeat
   stacked = tuple(_PrependRepeat(s) for s in body_meta.out_shapes)
   return py_utils.NestedMap(flops=flops, out_shapes=stacked)
예제 #3
0
 def FPropMeta(cls, p, *args):
   """Passes through the first p.n input shapes at zero cost.

   When p.n <= 1 only the first input shape is returned.
   """
   py_utils.CheckShapes(args)
   out_shapes = args[:p.n] if p.n > 1 else (args[0],)
   return py_utils.NestedMap(flops=0, out_shapes=out_shapes)
예제 #4
0
  def FPropMeta(cls, p, *args):
    """Propagates shapes through a graph of sub-layers described by signatures.

    Input shapes are bound to p.input_endpoints. Each (signature, sub) pair in
    p.sub reads its inputs from a shared GraphTensors table and writes its
    outputs back to it. The result holds the summed flops and the shapes found
    at p.output_endpoints.
    """
    py_utils.CheckShapes(args)

    table = GraphTensors()
    assert len(p.input_endpoints) == len(args)
    for endpoint, shape in zip(p.input_endpoints, args):
      table.StoreTensor(endpoint, shape)

    flops = 0
    for signature, sub in p.sub:
      sig = GraphSignature(signature)
      # Resolve this sub-layer's named inputs against the table.
      resolved = py_utils.NestedMap(inputs=sig.inputs).Transform(
          table.GetTensor)
      sub_meta = sub.cls.FPropMeta(sub, *resolved.inputs)
      flops += sub_meta.flops
      sub_outputs = sub_meta.out_shapes
      assert len(sub_outputs) == len(sig.outputs)
      for endpoint, shape in zip(sig.outputs, sub_outputs):
        table.StoreTensor(endpoint, shape)

    out_shapes = tuple(table.GetTensor(name) for name in p.output_endpoints)
    return py_utils.NestedMap(flops=flops, out_shapes=out_shapes)
예제 #5
0
 def FPropMeta(cls, p, inputs):
   """Estimates projecting the last dim of `inputs` to p.output_dims.

   The last dim of `inputs` must equal p.input_dims.
   """
   py_utils.CheckShapes((inputs,))
   assert p.input_dims == inputs[-1]
   # A multiply-accumulate c_{ij} += x_{ik} * y_{kj} counts as 2 flops.
   flops = inputs.size * p.output_dims * 2
   out_shape = inputs[:-1] + [p.output_dims]
   return py_utils.NestedMap(flops=flops, out_shapes=(out_shape,))
예제 #6
0
 def FPropMeta(cls, p, inputs, paddings):
   """Estimates cost and output shape for a rank-4 input whose 3rd dim is 1.

   The output channel count is
   filter_shape[2] * filter_shape[3] * weight_tiling_factor.
   `paddings` is passed through unchanged as the second output shape.
   """
   py_utils.CheckShapes((inputs, paddings))
   batch, seq_len, freq, _ = inputs
   assert freq == 1
   out_channels = (
       p.filter_shape[2] * p.filter_shape[3] * p.weight_tiling_factor)
   out_shape = tshape.Shape([batch, seq_len, freq, out_channels])
   # Constant factor 5 kept from the original estimate.
   flops = batch * seq_len * freq * p.filter_shape[0] * out_channels * 5
   return py_utils.NestedMap(flops=flops, out_shapes=(out_shape, paddings))
예제 #7
0
  def FPropMeta(cls, p, *args):
    """Runs every sub-layer's FPropMeta on the same inputs and merges them.

    Each sub-layer's out_shapes are collected and passed to p.merge_meta.
    The merged meta's flops are then increased by the sub-layers' total flops.
    """
    py_utils.CheckShapes(args)
    sub_flops = 0
    sub_outputs = []
    for child in p.sub:
      tf.logging.vlog(1, '  par abs fprop %s %s %d %s', child.name, child.cls,
                      len(args), str(args))
      child_meta = child.cls.FPropMeta(child, *args)
      py_utils.CheckShapes(child_meta.out_shapes)
      child_meta.VLog(
          1, '  par abs fprop {} {} {} {}'.format(child.name, child.cls,
                                                  len(args), str(args)))
      sub_flops += child_meta.flops
      sub_outputs.append(child_meta.out_shapes)

    merged = p.merge_meta(sub_outputs)
    py_utils.CheckShapes(merged.out_shapes)
    merged.flops += sub_flops
    return merged
예제 #8
0
 def FPropMeta(cls, p, *args):
   """Applies p.fn_meta to each non-None input shape independently.

   None inputs map to None outputs. The (cost, shape) pairs returned by
   p.fn_meta are summed into `flops` and collected into `out_shapes`.
   """
   total_flops = 0
   out_shapes = []
   for shape_in in args:
     if shape_in is not None:
       cost, shape_out = p.fn_meta(shape_in)
       py_utils.CheckShapes((shape_out,))
       total_flops += cost
     else:
       shape_out = None
     out_shapes.append(shape_out)
   return py_utils.NestedMap(flops=total_flops, out_shapes=tuple(out_shapes))
def _common_gpipe_transformer_fprop_meta(p, inputs, *args):
    """GPipe FPropMeta function.

    Flops are estimated as 5 * src_time^2 * batch * dim from the rank-3
    `inputs` shape (src_time, source_batch, dim). The output shapes are
    `inputs` followed by the (possibly rewritten) extra shapes in `args`.

    When `not p.has_aux_atten and p.is_transparent`, the extra shapes are
    rewritten in this order:
      1. If p.transparent_merger_tpl is set, everything from slot 5 on is
         replaced by (inputs, Shape([num_sources])).
      2. Slot 6 is replaced by a shape whose single dim is one less than
         args[6][0], and anything after slot 6 is dropped. At least 7 extra
         shapes must exist at this point, otherwise IndexError is raised.
      3. If p.final_enc_layer, everything from slot 5 on becomes (None, None).

    Returns:
      NestedMap(flops=..., out_shapes=(inputs,) + args).
    """
    # TODO(huangyp): return accurate estimate of flops.
    py_utils.CheckShapes((inputs, ))
    flops_per_element = 5
    src_time, source_batch, dim = inputs
    flops = flops_per_element * src_time * src_time * source_batch * dim
    # `args` is always a tuple here because it comes from *args, so the tuple
    # slicing below needs no normalization.
    if not p.has_aux_atten and p.is_transparent:  # Transparent Encoder FPropMeta
        if p.transparent_merger_tpl is not None:
            args = args[:5] + (
                inputs, tshape.Shape([p.transparent_merger_tpl.num_sources]))
        args = args[:6] + (tshape.Shape([args[6][0] - 1]), )
        if p.final_enc_layer:
            args = args[:5] + (None, None)
    return py_utils.NestedMap(flops=flops, out_shapes=(inputs, ) + args)
 def FPropMeta(cls, p, inputs, *args):
     """Returns estimated flops and output shapes for the embedding step.

     Args:
       p: layer params. Reads p.token_emb.vocab_size,
         p.token_emb.embedding_dim, p.add_tgt_embedding_layer and
         p.ret_task_ids.
       inputs: rank-2 source shape [src_dim_0, src_dim_1].
       *args: extra shapes. args[1] is unpacked as a rank-2 target shape when
         p.add_tgt_embedding_layer is set.

     Returns:
       NestedMap with `flops` and `out_shapes`. `out_shapes` is the source
       shape with the embedding dim appended, followed by the first five extra
       shapes (args[1] also gains the embedding dim when the target embedding
       is enabled), then (None, None). The shapes after slot 6 are appended
       only when p.ret_task_ids is set.
     """
     # TODO(ankurbpn): return accurate estimate of flops.
     py_utils.CheckShapes((inputs, ))
     flops_per_element = 2  # Is this correct?
     vocab = p.token_emb.vocab_size
     dim = p.token_emb.embedding_dim
     src_dim_0, src_dim_1 = inputs
     flops = flops_per_element * src_dim_0 * src_dim_1 * dim * vocab
     # `args` is always a tuple here (it comes from *args), so it needs no
     # normalization before being copied into a list.
     new_inputs = tshape.Shape([src_dim_0, src_dim_1, dim])
     new_args = list(args)
     if p.add_tgt_embedding_layer:
         tgt_dim_0, tgt_dim_1 = args[1]
         new_args[1] = tshape.Shape([tgt_dim_0, tgt_dim_1, dim])
     if p.ret_task_ids:
         new_args = new_args[:5] + [None, None] + new_args[7:]
     else:
         new_args = new_args[:5] + [None, None]
     return py_utils.NestedMap(flops=flops,
                               out_shapes=(new_inputs, ) + tuple(new_args))
예제 #11
0
 def FPropMeta(cls, p, inputs):
     """Returns metadata about the `FProp` computation for this layer.

     Flops are the element count of `inputs` times _BN_FLOPS_PER_ELEMENT.
     The single output shape equals the input shape.
     """
     py_utils.CheckShapes((inputs, ))
     element_count = inputs.num_elements()
     flops = element_count * _BN_FLOPS_PER_ELEMENT
     return py_utils.NestedMap(flops=flops, out_shapes=(inputs, ))
예제 #12
0
 def FPropMeta(cls, p, *args):
   """Estimates p.repeat applications of p.body; out_shapes equal the inputs.

   The body's out_shapes are validated but not returned.
   NOTE(review): presumably the body is shape-preserving; confirm.
   """
   py_utils.CheckShapes(args)
   body_meta = p.body.cls.FPropMeta(p.body, *args)
   py_utils.CheckShapes(body_meta.out_shapes)
   return py_utils.NestedMap(
       flops=body_meta.flops * p.repeat, out_shapes=args)
예제 #13
0
 def FPropMeta(cls, p, inputs):
   """Flops equal inputs.size; the output shape equals the input shape.

   The last dim of `inputs` must equal p.dims.
   """
   py_utils.CheckShapes((inputs,))
   assert inputs[-1] == p.dims
   unchanged = (inputs,)
   return py_utils.NestedMap(flops=inputs.size, out_shapes=unchanged)
예제 #14
0
 def FPropMeta(cls, p, *args):
   """Delegates the estimate to p.fn_meta, validating shapes in and out."""
   py_utils.CheckShapes(args)
   result = p.fn_meta(*args)
   py_utils.CheckShapes(result.out_shapes)
   return result
예제 #15
0
 def FPropMeta(cls, p, *args):
   """Forwards the estimate to p.body's FPropMeta after validating inputs."""
   py_utils.CheckShapes(args)
   body = p.body
   return body.cls.FPropMeta(body, *args)
예제 #16
0
 def FPropMeta(cls, p, *args):
   """Identity estimate: zero flops, out_shapes are the input shapes."""
   py_utils.CheckShapes(args)
   meta = py_utils.NestedMap(flops=0, out_shapes=args)
   return meta
예제 #17
0
 def FPropMeta(cls, p, inputs, padding=None):
     """Returns flops and output shape metadata for this layer.

     Flops are the element count of `inputs` times _BN_FLOPS_PER_ELEMENT.
     The single output shape equals the input shape. `padding` is accepted
     but does not affect the estimate.
     """
     py_utils.CheckShapes((inputs, ))
     element_count = inputs.num_elements()
     flops = element_count * _BN_FLOPS_PER_ELEMENT
     return py_utils.NestedMap(flops=flops, out_shapes=(inputs, ))