def FPropMeta(cls, p, *args):
    """Estimates flops and output shapes for `p.repeat` copies of `p.body`.

    Args:
      cls: Unused here.
      p: Params object; `p.body` supplies the wrapped layer params (its
        `cls.FPropMeta` is invoked) and `p.repeat` the multiplier.
      *args: Input descriptors. `None` entries are forwarded as `None`.

    Returns:
      A `py_utils.NestedMap` with `flops` (the body's flops multiplied by
      `p.repeat`) and `out_shapes` (a tuple in which every non-None body
      output shape is prefixed with `p.repeat`).
    """
    py_utils.CheckShapes(args)

    # The body is queried with each input stripped of its leading dimension.
    # NOTE(review): every arg is validated by CheckShapes and then queried via
    # get_shape(); confirm the argument type supports both.
    body_inputs = []
    for arg in args:
        if arg is None:
            body_inputs.append(None)
        else:
            body_inputs.append(tshape.Shape(arg.get_shape().as_list()[1:]))

    body_meta = p.body.cls.FPropMeta(p.body, *body_inputs)
    py_utils.CheckShapes(body_meta.out_shapes)

    total_flops = body_meta.flops * p.repeat

    # Re-attach a leading dimension of size `p.repeat` to each body output.
    stacked_shapes = tuple(
        None if shape is None else tshape.Shape([p.repeat] + shape[:])
        for shape in body_meta.out_shapes
    )
    return py_utils.NestedMap(flops=total_flops, out_shapes=stacked_shapes)
def _common_gpipe_transformer_fprop_meta(p, inputs, *args): """GPipe FPropMeta function.""" # TODO(huangyp): return accurate estimate of flops. py_utils.CheckShapes((inputs, )) flops_per_element = 5 src_time, source_batch, dim = inputs flops = flops_per_element * src_time * src_time * source_batch * dim args = args if isinstance(args, tuple) else (args, ) if not p.has_aux_atten and p.is_transparent: # Transparent Encoder FPropMeta if p.transparent_merger_tpl is not None: args = args[:5] + ( inputs, tshape.Shape([p.transparent_merger_tpl.num_sources])) args = args[:6] + (tshape.Shape([args[6][0] - 1]), ) if p.final_enc_layer: args = args[:5] + (None, None) return py_utils.NestedMap(flops=flops, out_shapes=(inputs, ) + args)
def _InferOutShapes(self, args):
    """Asks `self.body.FPropMeta` for output shapes and converts them.

    Args:
      args: Iterable of inputs. Each non-None entry must provide
        `get_shape().as_list()`; its leading dimension is dropped before the
        body is queried. `None` entries are forwarded as `None`.

    Returns:
      A list with one entry per body output shape: `None` where the body
      reported `None`, otherwise the result of that shape's
      `ToTensorShape()`.
    """
    body_inputs = []
    for arg in args:
        stripped = None
        if arg is not None:
            stripped = tshape.Shape(arg.get_shape().as_list()[1:])
        body_inputs.append(stripped)

    body_meta = self.body.FPropMeta(self.body.params, *body_inputs)
    return [
        shape.ToTensorShape() if shape is not None else None
        for shape in body_meta.out_shapes
    ]
def FPropMeta(cls, p, inputs, paddings):
    """Computes the output shape and a flops estimate for this layer.

    Args:
      cls: Unused here.
      p: Params object; reads `filter_shape` (indices 0, 2 and 3) and
        `weight_tiling_factor`.
      inputs: Shape that unpacks into exactly four dimensions; the third must
        equal 1.
      paddings: Shape returned unchanged as the second output shape.

    Returns:
      A `py_utils.NestedMap` with `flops` and `out_shapes`, where
      `out_shapes` is `(output_shape, paddings)`.
    """
    py_utils.CheckShapes((inputs, paddings))

    # Presumably (batch, time, frequency, channels) -- TODO confirm.
    n_batch, n_time, n_freq, _ = inputs
    assert n_freq == 1

    out_channels = (
        p.filter_shape[2] * p.filter_shape[3] * p.weight_tiling_factor)
    output_shape = tshape.Shape([n_batch, n_time, n_freq, out_channels])

    # The trailing factor of 5 is a fixed per-element multiplier.
    flops_estimate = (
        n_batch * n_time * n_freq * p.filter_shape[0] * out_channels * 5)
    return py_utils.NestedMap(
        flops=flops_estimate, out_shapes=(output_shape, paddings))
def FPropMeta(cls, p, inputs, *args):
    """Computes output shapes and a rough flops estimate for the embedding.

    Args:
      cls: Unused here.
      p: Params object; reads `token_emb.vocab_size`,
        `token_emb.embedding_dim`, `add_tgt_embedding_layer` and
        `ret_task_ids`.
      inputs: Shape that unpacks into exactly two dimensions.
      *args: Additional positional shapes. When `p.add_tgt_embedding_layer`
        is set, `args[1]` must unpack into two dimensions.

    Returns:
      A `py_utils.NestedMap` with `flops` and `out_shapes`. The first output
      shape is `inputs` with the embedding dim appended; the remaining
      entries are the first five `args` (entry 1 rewritten when
      `p.add_tgt_embedding_layer` is set), two `None` entries, and, only when
      `p.ret_task_ids` is set, whatever followed index 6 in `args`.
    """
    # TODO(ankurbpn): return accurate estimate of flops.
    py_utils.CheckShapes((inputs, ))

    token_emb = p.token_emb
    vocab_size = token_emb.vocab_size
    emb_dim = token_emb.embedding_dim

    src_dim_0, src_dim_1 = inputs
    # TODO: confirm that 2 flops per element is the right multiplier.
    flops = 2 * src_dim_0 * src_dim_1 * emb_dim * vocab_size

    # `*args` is already a tuple; the normalization is kept as a safeguard.
    if not isinstance(args, tuple):
        args = (args, )

    embedded_inputs = tshape.Shape([src_dim_0, src_dim_1, emb_dim])

    rest = list(args)
    if p.add_tgt_embedding_layer:
        tgt_dim_0, tgt_dim_1 = args[1]
        rest[1] = tshape.Shape([tgt_dim_0, tgt_dim_1, emb_dim])

    # Entries 5 and 6 are always reported as None; anything after them is
    # only retained when task ids are returned.
    trailing = rest[7:] if p.ret_task_ids else []
    rest = rest[:5] + [None, None] + trailing

    return py_utils.NestedMap(
        flops=flops, out_shapes=(embedded_inputs, ) + tuple(rest))
def FPropMeta(cls, p, inputs, *args):
    """Reports the logits shape together with a fixed flops figure.

    Args:
      cls: Unused here.
      p: Params object; reads `inputs_from_decoder` and `num_classes`.
      inputs: Shape whose first two dimensions are used unless
        `p.inputs_from_decoder` is set.
      *args: Additional positional shapes; `args[1]` supplies the first two
        dimensions when `p.inputs_from_decoder` is set.

    Returns:
      A `py_utils.NestedMap` with `flops` fixed at 100 and `out_shapes`
      holding a single shape `[dim1, dim2, p.num_classes]`.
    """
    # Only the selected source is indexed, so `args` may be empty when the
    # inputs do not come from the decoder.
    if p.inputs_from_decoder:
        source_shape = args[1]
    else:
        source_shape = inputs
    dim1, dim2 = source_shape[:2]

    logits_shape = tshape.Shape([dim1, dim2, p.num_classes])
    return py_utils.NestedMap(flops=100, out_shapes=(logits_shape, ))
def _ToTShape(x):
    """Wraps `x.as_list()` in a `tshape.Shape`, passing `None` through.

    Args:
      x: `None`, or an object exposing `as_list()`.

    Returns:
      `None` when `x` is `None`, otherwise `tshape.Shape(x.as_list())`.
    """
    return None if x is None else tshape.Shape(x.as_list())