Example #1
0
    def __init__(self, proto_network, batch_size, callback):
        """Instantiate a proto network and expose its variable instances.

        The proto network first has its loop control expanded. It is then
        promoted through ``callback`` and called with ``batch_size``, which
        populates the ``variable_instance`` of each of its variables and
        parameters.

        Args:
            proto_network: Proto network to instantiate.
            batch_size: Batch size forwarded when calling the promoted network.
            callback: Object handed to ``promote``.
        """
        expanded = proto_network.expand_loop_control()
        self.proto_network = expanded.promote(callback)
        self.proto_network(batch_size=batch_size)

        net = self.proto_network
        # Every (name, proto) pair: variables first, then parameters.
        named = list(itertools.chain(net.variables.items(),
                                     net.parameters.items()))

        # Give each instance the name it is registered under in the proto net.
        for name, proto in named:
            proto.variable_instance.name = name

        self._inputs = {
            name: net.variables[name].variable_instance for name in net.inputs
        }
        self._outputs = {
            name: net.variables[name].variable_instance for name in net.outputs
        }
        # Parameters come after variables in ``named``, so a parameter wins
        # if the same name appears in both mappings.
        self._variables = {name: proto.variable_instance for name, proto in named}

        # Publish the parameters into the currently active parameter scope,
        # mirroring the original implementation.
        active_scope = nn.get_current_parameter_scope()
        with nn.parameter_scope('', active_scope):
            for name, proto in net.parameters.items():
                nn.parameter.set_parameter(name, proto.variable_instance)
Example #2
0
    def __init__(self,
                 network_proto,
                 scope,
                 batch_size=None,
                 rng=None,
                 callback=None):
        """Build a computation graph from a network proto message.

        Args:
            network_proto: Network message; its ``batch_size``, ``variable``
                and ``function`` fields are read.
            scope: Parameter scope the functions are created under.
            batch_size: Batch size; falls back to ``network_proto.batch_size``.
            rng: Random state stored as ``self.rng``; falls back to
                ``np.random.RandomState(1223)``.
            callback: Pass object driving function passes, stop-at handling
                and logging; falls back to a fresh ``NnpNetworkPass()``
                (annotated in the original code as "no pass").

        On return, ``self.variables``, ``self.inputs`` and ``self.outputs``
        map names to the ``.variable`` of the corresponding protos.
        """
        self.batch_size = (network_proto.batch_size
                           if batch_size is None else batch_size)
        self.rng = np.random.RandomState(1223) if rng is None else rng
        if callback is None:
            callback = NnpNetworkPass()

        # Wrap the raw messages; variables are indexed by their name.
        variables = {var.name: VariableProto(var)
                     for var in network_proto.variable}
        functions = [FunctionProto(func) for func in network_proto.function]

        # Wire each function to its wrapped input and output variables.
        # Both lists are built before either attribute is assigned.
        for func in functions:
            func.inputs, func.outputs = (
                [variables[name] for name in func.proto.input],
                [variables[name] for name in func.proto.output],
            )

        # Let the callback apply its per-type and per-name function passes,
        # skipping functions marked as disabled.
        for func in self._functions_in_forward_order(variables):
            if not func.disabled:
                callback._apply_function_pass_by_type(func, variables, scope)
                callback._apply_function_pass_by_name(func, variables, scope)

        # Apply the callback's stop-at ("use up to") handling to every
        # function's inputs.
        for func in self._functions_in_forward_order(variables):
            callback._apply_use_up_to(func.inputs)

        # Create the functions inside ``scope``. The scope active before
        # entering it is captured first and handed to ``_create_function``.
        outer_scope = nn.get_current_parameter_scope()
        created = 0
        with nn.parameter_scope('', scope):
            for created, func in enumerate(
                    self._functions_in_forward_order(variables), start=1):
                self._create_function(func, callback, outer_scope)
        callback.verbose2('Created {} functions.'.format(created))

        # Filter the variable protos, then collect the input and output
        # protos through the class helpers.
        variables = self._filter_variables(variables)
        inputs = self._get_inputs(variables)
        outputs = self._get_outputs(variables)

        # Publish name -> variable maps for all, input and output protos.
        self.variables = {proto.name: proto.variable
                          for proto in variables.values()}
        self.inputs = {proto.name: proto.variable for proto in inputs}
        self.outputs = {proto.name: proto.variable for proto in outputs}