def tuple(op_name, input_layers, **kwargs):
    # type: (str, List[XLayer], **Any) -> XLayer
    """
    Create a tuple XLayer for grouping a list of input layers

    Arguments
    ---------
    op_name: str
        The name of the new tuple XLayer; it is also used as the single
        entry of the layer's `layer` field
    input_layers: List[XLayer]
        The input layers to be grouped in a tuple data structure. Any
        iterable of XLayers is accepted; it is materialized into a list
        because it is traversed more than once below.
    kwargs:
        Extra attributes, stored as-is in the new layer's `attrs`

    Returns
    -------
    XLayer
        A layer of type ['Tuple'] whose bottoms are the names of the input
        layers (in order) and whose shapes are a TupleShape built from the
        input layers' shapes. `tops` and `targets` start out empty.
    """
    # NOTE: this function shadows the builtin `tuple` inside its module;
    # the name is part of the public op API and is therefore kept.

    # The layers are walked twice (names, then shapes). Without this, a
    # one-shot iterator (e.g. a generator) would produce complete bottoms
    # but an empty TupleShape, silently.
    input_layers = list(input_layers)

    bottoms = [input_layer.name for input_layer in input_layers]
    # `[:]` copies each input's shape sequence, presumably so the tuple's
    # shapes do not alias the input layers' shape objects - confirm.
    shapes = TupleShape([TensorShape(il.shapes[:]) for il in input_layers])

    X = XLayer()
    X = X._replace(
        name=op_name,
        type=['Tuple'],
        shapes=shapes,
        sizes=shapes.get_size(),
        layer=[op_name],
        tops=[],
        bottoms=bottoms,
        attrs=kwargs,
        targets=[]
    )
    return X
def test_get_size(self):
    """TupleShape.get_size yields one element count per member shape.

    The expected values show that the leading -1 entry (presumably an
    unknown batch dimension) does not make the reported sizes negative.
    """
    member_shapes = [TensorShape([-1, 2]), TensorShape([-1, 2, 4])]
    tuple_shape = TupleShape(member_shapes)

    expected_sizes = [2, 8]
    assert tuple_shape.get_size() == expected_sizes