Exemplo n.º 1
0
def tuple(op_name, input_layers, **kwargs):
    # type: (str, List[XLayer], **Any) -> XLayer
    """
    Create a tuple XLayer for grouping a list of input layers

    Arguments
    ---------
    op_name: str
        The name of the new layer; also used as its single `layer` entry
    input_layers: List[XLayer]
        The input layers to be grouped in a tuple data structure. Any
        iterable of XLayers is accepted: it is materialized once so that a
        one-shot iterator (e.g. a generator) is not exhausted by the first
        of the two passes below.
    **kwargs
        Stored unchanged as the new layer's `attrs`

    Returns
    -------
    XLayer
        A layer of type 'Tuple' whose `bottoms` are the input layer names
        (in input order) and whose `shapes` is a TupleShape built from the
        input layers' shapes; `tops` and `targets` start out empty
    """
    # The inputs are traversed twice (names, then shapes). Materialize them
    # once so both passes see the same layers; with a generator the second
    # pass would otherwise silently produce an empty TupleShape.
    input_layers = list(input_layers)

    bottoms = [input_layer.name for input_layer in input_layers]
    # Slice each input's shapes (presumably a copy, so the tuple does not
    # alias the input layers' shape objects) and wrap it as a TensorShape.
    shapes = TupleShape([TensorShape(il.shapes[:]) for il in input_layers])

    X = XLayer()
    X = X._replace(name=op_name,
                   type=['Tuple'],
                   shapes=shapes,
                   sizes=shapes.get_size(),
                   layer=[op_name],
                   tops=[],
                   bottoms=bottoms,
                   attrs=kwargs,
                   targets=[])

    return X
Exemplo n.º 2
0
    def test_get_size(self):
        """Check that TupleShape.get_size yields one size per member shape.

        Per this test's expectation, the leading -1 dimension does not
        contribute to a member's size: [-1, 2] -> 2 and [-1, 2, 4] -> 8.
        """
        first_member = TensorShape([-1, 2])
        second_member = TensorShape([-1, 2, 4])
        tuple_shape = TupleShape([first_member, second_member])

        expected_sizes = [2, 8]
        assert tuple_shape.get_size() == expected_sizes