コード例 #1
0
def _network(proto, default_context, batch_size, all_variables, rng):
    """Build a ``Network`` from a serialized network description.

    Repeat blocks are expanded: each variable and function in ``proto`` is
    instantiated once per index combination of its ``repeat_id`` entries,
    where each repeat id ranges over ``0 .. times - 1`` as declared in
    ``proto.repeat_info``.

    Args:
        proto: Network description message (presumably the protobuf
            ``Network`` message -- TODO confirm). Its ``name``,
            ``repeat_info``, ``batch_size``, ``variable`` and ``function``
            fields are read.
        default_context: Context used for functions whose own
            ``context.backends`` is empty.
        batch_size: Batch size override. When ``None``, ``proto.batch_size``
            is used instead. Any variable shape dimension below 1 is
            replaced by the resulting batch size.
        all_variables: Dict of variable name -> variable shared between
            calls. Existing entries are reused instead of recreated, and
            newly created variables are added to it (this argument is
            mutated).
        rng: Passed through to ``_create_variable`` for newly created
            variables.

    Returns:
        The populated network, after ``network.setup(optimize=True)``.
    """
    network = Network()
    network.name = proto.name

    # Read Repeat Info: repeat id -> number of repetitions.
    network.repeat_info = {r.id: r.times for r in proto.repeat_info}

    def _repeat_indices(repeat_ids):
        """Return an iterator over every index tuple for ``repeat_ids``.

        With no repeat ids this yields a single empty tuple, so
        non-repeated entries are still instantiated exactly once.
        """
        return itertools.product(
            *(range(network.repeat_info[rid]) for rid in repeat_ids))

    network.variables = OrderedDict()
    network.batch_size = (
        proto.batch_size if batch_size is None else batch_size)

    for v in proto.variable:
        for variable_index in _repeat_indices(v.repeat_id):
            # Make the name unique per expansion: a '{repeat_id}'
            # placeholder is replaced by '[i]', otherwise '_repeat_id[i]'
            # is appended.
            name = v.name
            for rid, i in zip(v.repeat_id, variable_index):
                placeholder = '{' + rid + '}'
                if placeholder in name:
                    name = name.replace(placeholder, '[' + str(i) + ']')
                else:
                    name += '_' + rid + '[' + str(i) + ']'

            if name in all_variables:
                # Reuse a variable created by an earlier call.
                variable = all_variables[name]
            else:
                shape = tuple(
                    d if d >= 1 else network.batch_size for d in v.shape.dim)
                variable = _create_variable(v, name, shape, rng)
                all_variables[name] = variable
            network.variables[name] = variable
            # Lazy %-args: the tuple is only stringified if DEBUG is enabled.
            logger.debug('%s', (
                name, variable.shape,
                v.initializer.type if v.initializer.type else '-',
                v.initializer.multiplier))

    network.functions = OrderedDict()
    # NOTE(review): function_inputs/function_outputs are only initialized
    # here; presumably filled by network.setup() or unused -- confirm.
    network.function_inputs = OrderedDict()
    network.function_outputs = OrderedDict()
    # variable -> [function producing it]
    network.variable_inputs = OrderedDict()
    # variable -> [functions consuming it]
    network.variable_outputs = OrderedDict()
    for f in proto.function:
        ctx = default_context if not f.context.backends else _context(
            f.context)

        for variable_index in _repeat_indices(f.repeat_id):
            function, input_variable_names, output_variable_names = \
                _create_function(ctx, network, f, variable_index)
            if function is None:
                continue
            network.functions[function.name] = function
            for v_name in output_variable_names:
                # Assignment: a later producer of the same variable
                # replaces an earlier one.
                network.variable_inputs[network.variables[v_name]] = [
                    function]
            for v_name in input_variable_names:
                network.variable_outputs.setdefault(
                    network.variables[v_name], []).append(function)

    network.setup(optimize=True)
    return network
コード例 #2
0
ファイル: load.py プロジェクト: zwsong/nnabla
def _network(proto, default_context, all_variables):
    """Build a ``Network`` from a serialized network description.

    Repeat blocks are expanded: each variable and function in ``proto`` is
    instantiated once per index combination of its ``repeat_id`` entries,
    where each repeat id ranges over ``0 .. times - 1`` as declared in
    ``proto.repeat_info``.

    Args:
        proto: Network description message (presumably the protobuf
            ``Network`` message -- TODO confirm). Its ``name``,
            ``repeat_info``, ``batch_size``, ``variable`` and ``function``
            fields are read.
        default_context: Context used for functions whose own
            ``context.backend`` is the empty string.
        all_variables: Dict of variable name -> variable shared between
            calls. Existing entries are reused instead of recreated, and
            newly created variables are added to it (this argument is
            mutated).

    Returns:
        The populated network, after ``network.setup(optimize=True)``.
    """
    network = Network()
    network.name = proto.name

    # Read Repeat Info: repeat id -> number of repetitions.
    network.repeat_info = {r.id: r.times for r in proto.repeat_info}

    def _repeat_indices(repeat_ids):
        """Return an iterator over every index tuple for ``repeat_ids``.

        With no repeat ids this yields a single empty tuple, so
        non-repeated entries are still instantiated exactly once.
        """
        return itertools.product(
            *(range(network.repeat_info[rid]) for rid in repeat_ids))

    network.variables = OrderedDict()
    network.batch_size = proto.batch_size
    for v in proto.variable:
        for variable_index in _repeat_indices(v.repeat_id):
            # Make the name unique per expansion by appending
            # '_repeat_id[i]' for every repeat id.
            name = v.name + ''.join(
                '_' + rid + '[' + str(i) + ']'
                for rid, i in zip(v.repeat_id, variable_index))
            if name in all_variables:
                # Reuse a variable created by an earlier call.
                variable = all_variables[name]
            else:
                # Dimensions below 1 stand for the batch dimension.
                shape = tuple(
                    d if d >= 1 else network.batch_size for d in v.shape.dim)
                variable = _create_variable(v, name, shape)
                all_variables[name] = variable
            network.variables[name] = variable
            # Lazy %-args: the tuple is only stringified if DEBUG is enabled.
            logger.debug('%s', (
                name, variable.shape,
                v.initializer.type if v.initializer.type else '-',
                v.initializer.multiplier))

    network.functions = OrderedDict()
    # NOTE(review): function_inputs/function_outputs are only initialized
    # here; presumably filled by network.setup() or unused -- confirm.
    network.function_inputs = OrderedDict()
    network.function_outputs = OrderedDict()
    # variable -> [function producing it]
    network.variable_inputs = OrderedDict()
    # variable -> [functions consuming it]
    network.variable_outputs = OrderedDict()
    for f in proto.function:
        ctx = default_context if f.context.backend == "" else context(
            f.context)

        for variable_index in _repeat_indices(f.repeat_id):
            function, input_variable_names, output_variable_names = \
                _create_function(ctx, network, f, variable_index)
            if function is None:
                continue
            network.functions[function.name] = function
            for v_name in output_variable_names:
                # Assignment: a later producer of the same variable
                # replaces an earlier one.
                network.variable_inputs[network.variables[v_name]] = [
                    function]
            for v_name in input_variable_names:
                network.variable_outputs.setdefault(
                    network.variables[v_name], []).append(function)

    network.setup(optimize=True)
    return network