Ejemplo n.º 1
0
def Placeholder(_inf, name='placeholder'):
    """Create a placeholder variable configured from ``_inf`` and register its name.

    Args:
        _inf: input descriptor; its ``shape`` (normalized through ``_as_tuple``)
            becomes the placeholder shape and its ``axis`` is passed as
            ``dynamic_axes``.
        name: name passed to ``placeholder_variable`` and recorded with
            ``_name_node``.

    Returns:
        The newly created placeholder variable.
    """
    shape = _as_tuple(_inf.shape)
    placeholder = placeholder_variable(shape=shape, dynamic_axes=_inf.axis, name=name)
    _name_node(placeholder, name)
    if _trace_layers:  # optional construction trace
        print("new " + _node_description(placeholder))
    return placeholder
Ejemplo n.º 2
0
def Placeholder(shape=None, name='placeholder'):
    """Create a named placeholder variable, optionally with an explicit shape.

    Args:
        shape: forwarded to ``placeholder_variable`` only when it is not
            ``None``; otherwise ``placeholder_variable``'s own default applies.
        name: name passed to ``placeholder_variable`` and recorded with
            ``_name_node``.

    Returns:
        The newly created placeholder variable.
    """
    # Only pass ``shape`` when the caller supplied one.
    options = {'name': name}
    if shape is not None:
        options['shape'] = shape
    var = placeholder_variable(**options)
    _name_node(var, name)
    if _trace_layers:  # optional construction trace
        print("new " + _node_description(var))
    return var
Ejemplo n.º 3
0
def _flatten_arg_tuple(args):
    """Flatten ``args`` into a flat tuple.

    ``args`` may be a single value, a tuple, or an arbitrarily nested tree of
    tuples (e.g. an LSTM-style ``(x, (h, c))``). A non-tuple value becomes a
    one-element tuple, and an empty tuple at any level contributes nothing.
    Only ``tuple`` instances are descended into; lists and other sequences
    are kept as single values.
    """
    if not isinstance(args, tuple):  # not a tuple: singleton
        return (args,)
    flat = []
    for item in args:
        flat.extend(_flatten_arg_tuple(item))
    return tuple(flat)


def _output_of_arg(arg):
    """Return ``arg.output`` when ``arg`` has one, otherwise ``arg`` itself.

    Objects without an ``output`` attribute (e.g. Variables) are passed
    through unchanged.
    """
    try:
        return arg.output
    except AttributeError:
        return arg  # Variables have no output()


def _apply(f, args):
    """Bind ``args`` to the placeholders of ``f`` and return the bound clone.

    Args:
        f: function whose ``placeholders`` are to be filled in.
        args: a single argument or a (possibly nested) tuple of arguments;
            it is flattened in order, and each element that has an
            ``output`` attribute is replaced by that output.

    Returns:
        ``f.clone(CloneMethod.share, ...)`` with each placeholder mapped to
        the corresponding flattened argument, after it has been passed to
        ``_name_and_extend_Function`` with the original function's name.

    Raises:
        TypeError: if the number of flattened arguments differs from the
            number of placeholders of ``f``.
    """
    from cntk.ops.functions import CloneMethod

    # TODO: output unwrapping should go into Function.replace_placeholders()
    args = [_output_of_arg(arg) for arg in _flatten_arg_tuple(args)]
    placeholders = f.placeholders  # f parameters to fill in
    if len(args) != len(placeholders):
        raise TypeError(
            "_apply ({}): number of arguments {} must match number of placeholders {}"
            .format(_node_description(f), len(args), len(placeholders)))
    # These are for logging/debugging only. Capture them before `f` is
    # rebound to the clone below.
    _function_name = _node_name(f)
    _function_description = _node_description(f)
    # Use a distinct loop variable. A Python 2 list comprehension leaks its
    # variable, so reusing `f` here would rebind the function being applied.
    _arg_description = ", ".join(_node_name(arg) for arg in args)
    f = f.clone(CloneMethod.share, dict(zip(placeholders, args)))
    _name_and_extend_Function(f, _function_name)
    if _trace_layers:
        print("{} = {} ({})".format(_node_description(f),
                                    _function_description, _arg_description))
    return f
Ejemplo n.º 4
0
def _extend_Function(f):
    """Add the layer-composition conveniences to ``f`` unless it is already callable.

    If ``f`` already has a ``__call__`` attribute, it is treated as already
    extended and returned unchanged. Otherwise the class of ``f`` is swapped
    in place for a newly created subclass ``FunctionEx``, which adds:

    * ``f(*args)`` -- applies ``f`` to ``args`` via ``_apply``;
    * ``f >> g``   -- forward composition, evaluated as ``g(f)``;
    * ``f._name()`` -- the debug name returned by ``_node_name``.

    Returns:
        ``f`` itself (the same object, possibly with a new class).
    """
    # Check before building the subclass. Creating FunctionEx is wasted work
    # on the early-return path, and it raises TypeError when f's class cannot
    # be subclassed (e.g. a plain Python function).
    if hasattr(f, '__call__'):  # already extended: don't do it again
        return f

    class FunctionEx(f.__class__):
        def __call__(self, *args):
            return _apply(self, _as_tuple(args))

        def __rshift__(self, other):
            return other(self)

        def _name(self):  # retrieve the debug name
            return _node_name(self)

    f.__class__ = FunctionEx
    if _trace_layers:
        print("def {}".format(_node_description(f)))
    return f