Example #1
0
    def train(self, network, data_set, name_metric_function, metric_value_min, nb_iterations_max, interval_check=100):
        """
        Apply the training process on a `DataSet`, until the metric
        value computed using metric_function *equals or is greater than*
        metric_value_min, or until nb_iterations_max learning cycles have
        been run.

        The metric is evaluated after every batch of `interval_check`
        learning cycles. The last batch is shortened when needed so that
        the total number of cycles never exceeds nb_iterations_max.

        :Parameters:
            network
                Object trained through its `learn_cycles` and
                `classify_data_set` methods.
            data_set : `DataSet`
                Data used both for learning and for computing the metric.
            name_metric_function : string
                Name of the `Metric` to build.
            metric_value_min
                Metric value at which the training stops.
            nb_iterations_max : integer
                Maximum number of learning cycles to run.
            interval_check : integer
                Number of learning cycles between two metric evaluations.

        :Returns:
            integer : the number of iterations in the learning (0 when
            nb_iterations_max is lower than 1).

        :Raises NpyValueError:
            If interval_check is lower than 1.

        :Raises NpyTransferFunctionError:
            If name_metric_function does not correspond to a metric function.
        """
        if interval_check < 1:
            raise NpyValueError('interval_check has to be greater or equal to 1')

        Factory.check_prefix(name_metric_function, Metric.prefix)
        metric_function = Factory.build_instance_by_name(name_metric_function)

        nb_iterations_current = 0
        while nb_iterations_current < nb_iterations_max:
            # Cap the batch to the remaining budget so nb_iterations_max is
            # a true upper bound on the number of learning cycles.
            nb_cycles = min(interval_check, nb_iterations_max - nb_iterations_current)
            network.learn_cycles(data_set, nb_cycles)
            data_classification = network.classify_data_set(data_set)
            metric_value_computed = metric_function.compute_metric(data_set, data_classification)
            nb_iterations_current += nb_cycles

            # Written as "not <" (rather than ">=") to keep the exact stop
            # condition of the comparison used to decide whether to go on.
            if not metric_value_computed < metric_value_min:
                break

        return nb_iterations_current
Example #2
0
 def set_label_function(self, name_label_function):
     """
     Build the label function named `name_label_function` and store it
     as this object's label function (`self._label_function`).

     :Parameters:
         name_label_function : string
             Name of the `Label` function to instantiate; it is checked
             against `Label.prefix` with `Factory.check_prefix`.

     :Raises NpyTransferFunctionError:
         If name_label_function does not correspond to a label function.
     """
     try:
         Factory.check_prefix(name_label_function, Label.prefix)
         label_function = Factory.build_instance_by_name(name_label_function)
     except NpyTransferFunctionError as e:
         # Re-raise with only the original message, using syntax valid in
         # both Python 2.6+ and Python 3.
         raise NpyTransferFunctionError(e.msg)

     # Only replace the current label function once the new one was built.
     self._label_function = label_function
Example #3
0
    def train_network(self, network, data_set, name_metric_function, metric_value_min, nb_iterations_max, interval_check):
        """
        Apply the training process on a `DataSet`, until the `Metric`
        value computed using metric_function *equals or is greater than*
        metric_value_min. This makes the assumption that the metric
        function gives higher values for higher network performances.
        The training is stopped after nb_iterations_max to avoid infinite
        loops due to unreachable `Metric` values.

        :Raises NpyValueError:
            If interval_check is lower than 1, or if nb_iterations_max is
            not None and lower than 1.

        :Raises NpyTransferFunctionError:
            If name_metric_function does not correspond to a metric function.
        """
        # NOTE(review): as written, this body only validates its arguments
        # and builds metric_function; network, data_set, metric_value_min
        # and metric_function are never used and None is returned. The
        # training loop described in the docstring appears to be missing --
        # confirm against the original implementation.
        if interval_check < 1:
            raise NpyValueError('interval_check has to be greater or equal to 1.')

        # None means "no explicit limit" and is accepted as-is.
        if nb_iterations_max is not None and nb_iterations_max < 1:
            raise NpyValueError('nb_iterations_max has to be greater or equal to 1, or equal to None.')

        try:
            Factory.check_prefix(name_metric_function, Metric.prefix)
            metric_function = Factory.build_instance_by_name(name_metric_function)
        except NpyTransferFunctionError as e:
            # Re-raise with only the original message (Python 2.6+/3 syntax).
            raise NpyTransferFunctionError(e.msg)
Example #4
0
    def add_unit(self, nb_nodes, name_activation_function=None, name_update_function=None, name_error_function=None):
        """
        Adds a unit to the network as the new output unit. Takes care of
        making the connections with the previous unit.

        The first unit added becomes the input unit (a `UnitInput`) and
        needs no transfer functions. Every later unit is connected to the
        current last unit (or to the input unit when it is the first
        non-input unit) and requires both an activation and an update
        function.

        :Parameters:
            nb_nodes : integer
                Number of nodes required in the unit. Must be strictly
                positive.
            name_activation_function : string
                Name of the `Activation` to use to compute the activation
                function for the current unit. Required for every unit
                except the input unit.
            name_update_function : string
                Name of the `Update` to use to compute the updates to
                the weights. Required for every unit except the input unit.
            name_error_function : string
                Name of the `Error` to use to compute the error of the `Unit`.
                If equal to None, then the error function is set
                automatically, depending on the unit position in the network.

        :Returns:
            The `Unit` that has just been added to the network. In the case
            of the input unit, None is returned.

        :Raises NpyTransferFunctionError:
            If the function names do not correspond to valid functions.

        :Raises NpyUnitError:
            If an error related to the unit topology is encountered.
        """
        # A positive number of nodes is required
        if nb_nodes <= 0:
            raise NpyUnitError('Number of nodes must be strictly positive.')

        # And for the non-input units, the activation and update functions
        # must be defined.
        if (self.unit_input is not None
                and (name_activation_function is None or name_update_function is None)):
            raise NpyUnitError('Activation and update functions must be specified.')

        # Handle the input unit
        if self.unit_input is None:
            unit = UnitInput(nb_nodes)
            self.unit_input = unit
        else:
            # Handle the other units: connect to the current last unit, or
            # to the input unit when no other unit exists yet.
            if self.units:
                unit_previous = self.units[-1]
            else:
                unit_previous = self.unit_input

            # Always bind nb_previous_nodes: without bias the new unit only
            # sees the nodes of the previous unit.
            nb_previous_nodes = unit_previous.get_nb_nodes()
            if self.use_bias:
                # Add 1 in order to implement the bias
                nb_previous_nodes += 1

            # Retrieve transfer function instances
            try:
                Factory.check_prefix(name_activation_function, Activation.prefix)
                activation_function = Factory.build_instance_by_name(name_activation_function)

                Factory.check_prefix(name_update_function, Update.prefix)
                update_function = Factory.build_instance_by_name(name_update_function)

                if name_error_function is None:
                    error_function = None
                else:
                    Factory.check_prefix(name_error_function, Error.prefix)
                    error_function = Factory.build_instance_by_name(name_error_function)
            except NpyTransferFunctionError as e:
                # Re-raise with only the original message (Python 2.6+/3 syntax).
                raise NpyTransferFunctionError(e.msg)

            # Create the unit and add it to the network
            unit = Unit(nb_nodes, nb_previous_nodes, activation_function, update_function, error_function)
            self.units.append(unit)