def _network(proto, default_context, batch_size, all_variables, rng):
    """Build a ``Network`` object from its protobuf description.

    Every variable and function definition is expanded once per combination
    of the indices of its ``repeat_id`` entries; a definition without
    ``repeat_id`` is expanded exactly once.

    Args:
        proto: Network protobuf message. Its ``name``, ``repeat_info``,
            ``batch_size``, ``variable`` and ``function`` fields are read.
        default_context: Context used for functions whose
            ``context.backends`` is empty.
        batch_size: Overrides ``proto.batch_size`` when not ``None``. Any
            variable dimension smaller than 1 is replaced by the effective
            batch size.
        all_variables: Name -> variable mapping shared across calls. A
            variable already present under the expanded name is reused, and
            newly created variables are inserted (this argument is mutated).
        rng: Forwarded unchanged to ``_create_variable``.

    Returns:
        The populated network, after ``network.setup(optimize=True)``.
    """
    network = Network()
    network.name = proto.name

    # Read repeat info: repeat id -> number of repetitions.
    network.repeat_info = {r.id: r.times for r in proto.repeat_info}

    def repeat_indices(repeat_ids):
        # Cartesian product of range(times) over the given repeat ids.
        # Yields a single empty tuple when repeat_ids is empty.
        return itertools.product(
            *[range(network.repeat_info[rid]) for rid in repeat_ids])

    network.variables = OrderedDict()
    if batch_size is None:
        network.batch_size = proto.batch_size
    else:
        network.batch_size = batch_size

    for v in proto.variable:
        for variable_index in repeat_indices(v.repeat_id):
            # Expand the name for this repeat index. A "{rid}" placeholder in
            # the name is substituted in place; otherwise "_rid[i]" is
            # appended.
            name = v.name
            for rid, i in zip(v.repeat_id, variable_index):
                placeholder = '{' + rid + '}'
                if placeholder in name:
                    name = name.replace(placeholder, '[' + str(i) + ']')
                else:
                    name += '_' + rid + '[' + str(i) + ']'

            if name in all_variables:
                # Share the variable with other networks using the same name.
                variable = all_variables[name]
            else:
                shape = tuple(
                    [d if d >= 1 else network.batch_size for d in v.shape.dim])
                variable = _create_variable(v, name, shape, rng)
                all_variables[name] = variable
            network.variables[name] = variable
            logger.debug('%s', (
                name, variable.shape,
                v.initializer.type if v.initializer.type else '-',
                v.initializer.multiplier))

    network.functions = OrderedDict()
    # function_inputs / function_outputs are created here but not filled by
    # this loader.
    network.function_inputs = OrderedDict()
    network.function_outputs = OrderedDict()
    # variable -> [producing function] and variable -> [consuming functions]
    network.variable_inputs = OrderedDict()
    network.variable_outputs = OrderedDict()

    for f in proto.function:
        ctx = _context(f.context) if f.context.backends else default_context
        for variable_index in repeat_indices(f.repeat_id):
            function, input_variable_names, output_variable_names = \
                _create_function(ctx, network, f, variable_index)
            if function is None:
                continue
            network.functions[function.name] = function
            for v_name in output_variable_names:
                # Only the last producer of a variable is recorded.
                network.variable_inputs[network.variables[v_name]] = [
                    function]
            for v_name in input_variable_names:
                network.variable_outputs.setdefault(
                    network.variables[v_name], []).append(function)

    network.setup(optimize=True)
    return network
def _network(proto, default_context, all_variables):
    """Construct a ``Network`` from a network protobuf message.

    The batch size is always taken from ``proto.batch_size``. Each variable
    or function definition is instantiated once per combination of its
    ``repeat_id`` indices. Repeated variables get a ``_<rid>[<i>]`` suffix
    per repeat id.

    Args:
        proto: Network protobuf message providing ``name``, ``repeat_info``,
            ``batch_size``, ``variable`` and ``function``.
        default_context: Context used when a function's
            ``context.backend`` is the empty string.
        all_variables: Name -> variable mapping shared across networks.
            Existing entries are reused and new ones are added (mutated).

    Returns:
        The network after ``setup(optimize=True)`` has been called on it.
    """
    net = Network()
    net.name = proto.name

    # Read repeat info: repeat id -> number of repetitions.
    net.repeat_info = {info.id: info.times for info in proto.repeat_info}

    def index_combinations(repeat_ids):
        # Every index tuple over the listed repeat ids. This is one empty
        # tuple when there are none.
        return itertools.product(
            *[tuple(range(net.repeat_info[rid])) for rid in repeat_ids])

    net.variables = OrderedDict()
    net.batch_size = proto.batch_size

    for var_def in proto.variable:
        for combo in index_combinations(var_def.repeat_id):
            suffix = ''.join(
                ['_' + rid + '[' + str(pos) + ']'
                 for rid, pos in zip(var_def.repeat_id, combo)])
            var_name = var_def.name + suffix

            if var_name not in all_variables:
                var_shape = tuple(
                    [dim if dim >= 1 else net.batch_size
                     for dim in var_def.shape.dim])
                all_variables[var_name] = _create_variable(
                    var_def, var_name, var_shape)
            var = all_variables[var_name]

            net.variables[var_name] = var
            init_type = (var_def.initializer.type
                         if var_def.initializer.type else '-')
            logger.debug('{}'.format(
                (var_name, var.shape, init_type,
                 var_def.initializer.multiplier)))

    net.functions = OrderedDict()
    net.function_inputs = OrderedDict()
    net.function_outputs = OrderedDict()
    net.variable_inputs = OrderedDict()
    net.variable_outputs = OrderedDict()

    for func_def in proto.function:
        if f_ctx_is_default(func_def):
            func_ctx = default_context
        else:
            # NOTE(review): `context` is a module-level name not visible here.
            func_ctx = context(func_def.context)
        for combo in index_combinations(func_def.repeat_id):
            func, in_names, out_names = _create_function(
                func_ctx, net, func_def, combo)
            if func is None:
                continue
            net.functions[func.name] = func
            for out_name in out_names:
                net.variable_inputs[net.variables[out_name]] = [func]
            for in_name in in_names:
                net.variable_outputs.setdefault(
                    net.variables[in_name], []).append(func)

    net.setup(optimize=True)
    return net


def f_ctx_is_default(func_def):
    # True when the function definition names no backend.
    return func_def.context.backend == ""