Exemplo n.º 1
0
 def FPropMeta(cls, p, *args):
   """Computes meta info for `p.repeat` stacked applications of `p.body`.

   Args:
     p: layer params; reads `p.body` (whose `.cls.FPropMeta` is queried) and
       `p.repeat`.
     *args: inputs (or None). Each non-None one is expected to expose
       `get_shape().as_list()` (presumably a tensor — confirm); its shape is
       passed to the body with the leading dimension removed.

   Returns:
     NestedMap with `flops` equal to the body's flops times `p.repeat`, and
     `out_shapes`, a tuple of the body's output shapes, each with `p.repeat`
     prepended as a new leading dimension (None entries stay None).
   """
   py_utils.CheckShapes(args)

   def _DropLeadingDim(tensor):
     if tensor is None:
       return None
     return tshape.Shape(tensor.get_shape().as_list()[1:])

   def _PrependRepeat(shape):
     if shape is None:
       return None
     return tshape.Shape([p.repeat] + shape[:])

   body_inputs = [_DropLeadingDim(tensor) for tensor in args]
   body_meta = p.body.cls.FPropMeta(p.body, *body_inputs)
   py_utils.CheckShapes(body_meta.out_shapes)
   total_flops = body_meta.flops * p.repeat
   stacked_shapes = tuple(
       [_PrependRepeat(shape) for shape in body_meta.out_shapes])
   return py_utils.NestedMap(flops=total_flops, out_shapes=stacked_shapes)
def _common_gpipe_transformer_fprop_meta(p, inputs, *args):
    """GPipe FPropMeta function.

    Produces a rough flops estimate and propagates output shapes.

    Args:
      p: layer params; reads `has_aux_atten`, `is_transparent`,
        `transparent_merger_tpl` (and its `num_sources`) and
        `final_enc_layer`.
      inputs: shape of the main input, unpacked as
        (src_time, source_batch, dim).
      *args: shapes of the remaining FProp inputs. They are passed through
        to `out_shapes`, with the adjustments below on the transparent
        encoder path.

    Returns:
      NestedMap with `flops` and `out_shapes`, where `out_shapes` is
      `(inputs,)` followed by the (possibly adjusted) `args`.
    """
    # TODO(huangyp): return accurate estimate of flops.
    py_utils.CheckShapes((inputs, ))
    flops_per_element = 5
    src_time, source_batch, dim = inputs
    flops = flops_per_element * src_time * src_time * source_batch * dim
    # `*args` always binds a tuple, so no tuple normalization is needed here.
    if not p.has_aux_atten and p.is_transparent:  # Transparent Encoder FPropMeta
        if p.transparent_merger_tpl is not None:
            # Slot 5 becomes the input shape and slot 6 a 1-D shape holding
            # the number of merger sources; anything past slot 6 is dropped.
            args = args[:5] + (
                inputs, tshape.Shape([p.transparent_merger_tpl.num_sources]))
        # Slot 6 becomes a 1-D shape one smaller than slot 6's first dim;
        # anything past slot 6 is dropped. The right-hand side reads the
        # current `args` before rebinding.
        args = args[:6] + (tshape.Shape([args[6][0] - 1]), )
        if p.final_enc_layer:
            # The final encoder layer reports None for slots 5 and 6.
            args = args[:5] + (None, None)
    return py_utils.NestedMap(flops=flops, out_shapes=(inputs, ) + args)
Exemplo n.º 3
0
 def _InferOutShapes(self, args):
   """Returns the output shapes `self.body` reports for `args`.

   Each non-None arg's shape (from `get_shape().as_list()`) is passed to
   `self.body.FPropMeta` with its leading dimension removed. The resulting
   shapes are converted with `ToTensorShape()`. None entries pass through as
   None in both directions.

   Returns:
     A list with one entry per body output shape.
   """
   body_input_shapes = []
   for tensor in args:
     if tensor is None:
       body_input_shapes.append(None)
     else:
       body_input_shapes.append(
           tshape.Shape(tensor.get_shape().as_list()[1:]))
   body_out_shapes = self.body.FPropMeta(self.body.params,
                                         *body_input_shapes).out_shapes
   converted = []
   for shape in body_out_shapes:
     if shape is None:
       converted.append(None)
     else:
       converted.append(shape.ToTensorShape())
   return converted
Exemplo n.º 4
0
 def FPropMeta(cls, p, inputs, paddings):
   """Returns flops and output shapes for this layer.

   Args:
     p: layer params; reads `filter_shape` (indices 0, 2 and 3) and
       `weight_tiling_factor`.
     inputs: 4-D input shape (presumably [batch, time, freq, channels] —
       confirm); its third dimension must be 1. The last dimension is unused.
     paddings: paddings shape, returned unchanged.

   Returns:
     NestedMap with `out_shapes` = (outputs, paddings), where outputs is
     [b, t, f, filter_shape[2] * filter_shape[3] * weight_tiling_factor], and
     `flops` = b * t * f * filter_shape[0] * out_channels * 5.
   """
   py_utils.CheckShapes((inputs, paddings))
   dim_b, dim_t, dim_f, _ = inputs
   assert dim_f == 1
   out_channels = (
       p.filter_shape[2] * p.filter_shape[3] * p.weight_tiling_factor)
   out_shape = tshape.Shape([dim_b, dim_t, dim_f, out_channels])
   estimated_flops = (
       dim_b * dim_t * dim_f * p.filter_shape[0] * out_channels * 5)
   return py_utils.NestedMap(
       flops=estimated_flops, out_shapes=(out_shape, paddings))
 def FPropMeta(cls, p, inputs, *args):
     """Returns flops and output-shape metadata for the embedding step.

     Args:
       p: layer params; reads `token_emb.vocab_size`,
         `token_emb.embedding_dim`, `add_tgt_embedding_layer` and
         `ret_task_ids`.
       inputs: 2-D source shape, unpacked as (src_dim_0, src_dim_1).
       *args: shapes of the remaining inputs, passed through with the
         adjustments below.

     Returns:
       NestedMap with:
         flops: 2 * src_dim_0 * src_dim_1 * embedding_dim * vocab_size.
         out_shapes: ([src_dim_0, src_dim_1, embedding_dim],) followed by
           the adjusted `args`.
     """
     # TODO(ankurbpn): return accurate estimate of flops.
     py_utils.CheckShapes((inputs, ))
     flops_per_element = 2  # Is this correct?
     vocab = p.token_emb.vocab_size
     dim = p.token_emb.embedding_dim
     src_dim_0, src_dim_1 = inputs
     flops = flops_per_element * src_dim_0 * src_dim_1 * dim * vocab
     # `*args` always binds a tuple, so it can be used directly.
     new_inputs = tshape.Shape([src_dim_0, src_dim_1, dim])
     new_args = list(args)
     if p.add_tgt_embedding_layer:
         # args[1] is unpacked as a 2-D shape and gains a trailing
         # embedding dimension.
         tgt_dim_0, tgt_dim_1 = args[1]
         new_args[1] = tshape.Shape([tgt_dim_0, tgt_dim_1, dim])
     # Everything from index 5 on is replaced by two None entries. Entries
     # past index 6 are re-appended only when `p.ret_task_ids` is set.
     tail = new_args[7:] if p.ret_task_ids else []
     new_args = new_args[:5] + [None, None] + tail
     return py_utils.NestedMap(flops=flops,
                               out_shapes=(new_inputs, ) + tuple(new_args))
 def FPropMeta(cls, p, inputs, *args):
     """Returns the logits shape with a fixed placeholder flops value of 100.

     The first two dimensions are taken from `args[1]` when
     `p.inputs_from_decoder` is set, and otherwise from `inputs`. The last
     dimension is `p.num_classes`.
     """
     if p.inputs_from_decoder:
         source_shape = args[1]
     else:
         source_shape = inputs
     dim1, dim2 = source_shape[:2]
     logits_shape = tshape.Shape([dim1, dim2, p.num_classes])
     return py_utils.NestedMap(flops=100, out_shapes=(logits_shape, ))
Exemplo n.º 7
0
 def _ToTShape(x):
     """Returns `tshape.Shape(x.as_list())`, or None when `x` is None."""
     return None if x is None else tshape.Shape(x.as_list())