Example #1
0
    def create_param(self, param_name, shape, initializer, tags=None):
        """
        Creates parameter with a given name and initializer.

        If param_name is an instance of BlobReference - then this blob will be
        used to store the parameter (no logic will affect its location).

        If param_name is an instance of a string type, then the final blob will
        be created in the CurrentNameScope with the respect of all parameter
        sharing logic, i.e. 'resolved_name_scope/param_name'.

        Parameter sharing logic is going to override CurrentNameScope according
        to the rules that are specified through ParameterSharing contexts,
        all ParameterSharing contexts are applied recursively until there are no
        extra overrides present, where on each step the best match will be
        applied first.

        The following examples should clarify the way ParameterSharing logic
        works:

        As an example if this function is called with parameter 'w':
        a. Call from some scope 'global_scope' with no Parameter sharing:
          'global_scope/w'
        b. Call from scope 'scope_b', with override {'scope_b': 'scope_a'}:
          'scope_a/w'
        c. Call from scope 'scope_a', with override {'scope_a': ''}:
          'scope_a/w'
        d. Call from scope 'scope_b/shared', with overrides
          {'scope_b/shared': 'scope_b', 'scope_b': 'scope_a'}:
          'scope_a/w'
        e. Call from scope 'scope_b/unshared', with overrides
          {'scope_b/shared': 'scope_b', 'scope_b': 'scope_a'}:
          'scope_a/unshared/w'
        """
        # ParameterSharing works only for case when param_name is instance of
        # a string type. If param_name is a BlobReference - no attempt for
        # ParameterSharing will be applied.
        if isinstance(param_name, core.BlobReference):
            param_name = str(param_name)
        elif isinstance(param_name, six.string_types):
            # Parameter name will be equal to current Namescope that got
            # resolved with the respect of parameter sharing of the scopes.
            param_name = parameter_sharing_context.get_parameter_name(
                param_name)
        else:
            # Raise a proper exception type: raising a plain string is itself
            # a TypeError on Python 3 and masks the real error message.
            raise TypeError("Unsupported type for param_name")

        # Parameter sharing: if the resolved name is already registered,
        # return the existing blob (shapes must agree).
        if param_name in self._parameters_info:
            assert self._parameters_info[param_name].shape == shape
            return self._parameters_info[param_name].blob

        param_info = initializer.create_param(
            param_name=core.BlobReference(param_name),
            init_net=self.param_init_net,
            shape=shape,
        )
        # Attach an optimizer registered for one of the tags in the current
        # OptimizerContext; fall back to the default optimizer if none match.
        optim_context = OptimizerContext.current()
        for tag in self._normalize_tags(tags):
            if optim_context.has_optimizer(tag):
                # param_info will check optimizer has not been set
                param_info.optimizer = optim_context.get_optimizer(tag)
        if not param_info.optimizer and optim_context.has_optimizer(DEFAULT_OPTIM):
            param_info.optimizer = optim_context.get_optimizer(DEFAULT_OPTIM)

        reg_context = RegularizerContext.current()
        param_info.regularizer = reg_context

        self._parameters_info[param_name] = param_info
        # Add param to legacy structs as well, so all other functions for
        # parameters are still working.
        self.AddParameter(param_info.blob, tags)
        return param_info.blob
Example #2
0
    def create_param(self, param_name, shape, initializer, tags=None):
        """
        Create (or re-use) a parameter blob with the given name, shape and
        initializer.

        When ``param_name`` is a ``core.BlobReference`` the blob is used as-is
        (only stringified) and no parameter-sharing logic affects its location.
        When it is a plain string, the final name is resolved against the
        current name scope through the parameter-sharing contexts
        (``'resolved_name_scope/param_name'``); nested ``ParameterSharing``
        overrides are applied recursively, best match first.

        For a resolved name 'w' the overrides behave like this:
        a. Scope 'global_scope', no sharing:  'global_scope/w'
        b. Scope 'scope_b', override {'scope_b': 'scope_a'}:  'scope_a/w'
        c. Scope 'scope_a', override {'scope_a': ''}:  'scope_a/w'
        d. Scope 'scope_b/shared', overrides
           {'scope_b/shared': 'scope_b', 'scope_b': 'scope_a'}:  'scope_a/w'
        e. Scope 'scope_b/unshared', same overrides:  'scope_a/unshared/w'

        If a parameter with the resolved name is already registered, its blob
        is returned (the shapes must match). Otherwise the initializer runs on
        ``self.param_init_net``, an optimizer is attached from the current
        ``OptimizerContext``, the current ``RegularizerContext`` is recorded,
        and the blob is registered in both the new and legacy parameter
        bookkeeping.
        """
        # Resolve the final blob name. Parameter sharing only applies to
        # string names; a BlobReference pins the location explicitly.
        if isinstance(param_name, core.BlobReference):
            resolved_name = str(param_name)
        elif isinstance(param_name, six.string_types):
            resolved_name = parameter_sharing_context.get_parameter_name(
                param_name)
        else:
            raise TypeError("Unsupported type for param_name")

        # Shared parameter already created earlier - hand back its blob.
        if resolved_name in self._parameters_info:
            existing = self._parameters_info[resolved_name]
            assert existing.shape == shape
            return existing.blob

        info = initializer.create_param(
            param_name=core.BlobReference(resolved_name),
            init_net=self.param_init_net,
            shape=shape,
        )

        # Attach an optimizer for any matching tag (info rejects being set
        # more than once); fall back to the default optimizer if none matched.
        optimizers = OptimizerContext.current()
        for tag in self._normalize_tags(tags):
            if optimizers.has_optimizer(tag):
                info.optimizer = optimizers.get_optimizer(tag)
        if not info.optimizer and optimizers.has_optimizer(
                DEFAULT_OPTIM):
            info.optimizer = optimizers.get_optimizer(DEFAULT_OPTIM)

        info.regularizer = RegularizerContext.current()

        self._parameters_info[resolved_name] = info
        # Keep the legacy parameter structures in sync so all other
        # parameter-related helpers still work.
        self.AddParameter(info.blob, tags)
        return info.blob