예제 #1
0
    def train(self,
              train,
              validation,
              num_epochs=None,
              learning_rate=0.01,
              threshold=0.001,
              momentum=0.3,
              max_epochs=600,
              patience=100):
        """Train the FFNN with gradient descent. Dynamic stopping on lowest validation error.

        Every epoch performs ``len(train)`` weight updates, each on one sample
        picked with ``choice(train)``. After the epoch the validation error is
        re-evaluated and printed.

        Training runs over the given number of epochs. If none are given (``None``
        or any other falsy value), training runs for at most ``max_epochs`` epochs
        with dynamic stopping enabled: an epoch whose improvement in validation
        error (previous error minus current error) is below ``threshold`` -- which
        includes the error getting worse -- counts as a retry, and training halts
        once ``patience`` retries have accumulated since the last epoch that
        improved by at least ``threshold``. When dynamic stopping is used, the
        network finalizes the best weights found over the duration of training.

        Args:
            train: Sequence of training samples; each sample must expose
                ``features`` and ``label`` attributes.
            validation: Samples passed to ``self.evaluate`` to measure validation
                error.
            num_epochs: Fixed number of epochs to run. Falsy enables dynamic
                stopping instead.
            learning_rate: Learning rate forwarded to ``self.update_weights``.
            threshold: Minimum per-epoch drop in validation error that counts as
                real progress under dynamic stopping.
            momentum: Momentum forwarded to ``self.update_weights``.
            max_epochs: Upper bound on epochs when dynamic stopping is used.
            patience: Number of accumulated below-threshold epochs that triggers
                the dynamic stop.

        Returns:
            The number of epochs that were actually run.
        """

        dynamic_stopping = not num_epochs  # Dynamically halt if num_epochs unspecified
        num_epochs_iter = num_epochs if num_epochs else max_epochs
        num_samples = len(train)
        retries = 0
        err = self.evaluate(validation)

        progress_bar = ProgressBar()
        for epoch in range(num_epochs_iter):
            last_err = err
            for i in range(num_samples):
                progress_bar.refresh(i / num_samples)
                sample = choice(train)  # Randomly sample training data

                # Update weights based on the chosen sample
                self.prepare_network()
                self.propagate_input(sample.features)
                self.propagate_error(sample.label)
                self.update_weights(sample, learning_rate, momentum=momentum)

            progress_bar.refresh(1.0)
            progress_bar.clear()

            # Evaluate validation error
            err = self.evaluate(validation)
            print('Epoch {} validation error: {:.4f}'.format(epoch, err))
            if not dynamic_stopping:
                continue

            if last_err - err < threshold:
                if err <= last_err:  # Still improved, but below threshold
                    self.save_network_weights(err)

                retries += 1
                if retries >= patience:
                    # `epoch` is a 0-based index, so epoch + 1 epochs have
                    # completed; this keeps the return value a count, like the
                    # non-early-stop path below.
                    epochs_ran = epoch + 1
                    break
            else:
                # Improved by at least `threshold`: record and reset the counter
                self.save_network_weights(err)
                retries = 0
        else:
            epochs_ran = num_epochs_iter  # Loop did not stop early

        if dynamic_stopping:
            # Finalize weights to best validation error
            self.finalize_network_weights()

        return epochs_ran
예제 #2
0
    def evaluate(self, samples):
        """Return the root-mean-square error of the network over ``samples``.

        The squared error of each sample (``self.sq_error``) is accumulated,
        averaged over ``len(samples)`` and square-rooted. A progress bar is
        refreshed while the samples are processed and cleared afterwards.
        An empty ``samples`` raises ``ZeroDivisionError`` when averaging.
        """
        bar = ProgressBar()
        total = len(samples)
        squared_sum = 0
        for index, item in enumerate(samples):
            bar.refresh(index / total)
            squared_sum += self.sq_error(item)

        # Show the bar as complete before wiping it.
        bar.refresh(1.0)
        bar.clear()

        mean_squared = squared_sum / total
        return sqrt(mean_squared)
예제 #3
0
    def eval_classification(self, samples, verbose=False):
        """Return the mean classification error over a list of samples.

        Sums ``self.classification_error(sample, verbose)`` for every sample and
        divides by ``len(samples)``. Presumably each per-sample error is 1 for a
        misclassified sample and 0 otherwise, which would make the result the
        fraction of samples classified incorrectly -- confirm against
        ``classification_error``. An empty ``samples`` raises
        ``ZeroDivisionError``.
        """
        bar = ProgressBar()
        count = len(samples)
        error_total = 0
        for position, item in enumerate(samples):
            bar.refresh(position / count)
            error_total += self.classification_error(item, verbose)

        # Mark the bar complete, then remove it.
        bar.refresh(1.0)
        bar.clear()

        mean_error = error_total / count
        return mean_error
예제 #4
0
    def train_worker(self, train, learning_rate=0.01, batch_size=64):
        """Compute weight updates for one batch as a distributed-training worker.

        Runs ``batch_size`` iterations; each picks a sample with
        ``choice(train)``, propagates its features and label through the
        network, and passes the collection of updates to
        ``self.get_weight_updates``. The network's own weights are not updated
        by any call visible here.

        Returns:
            The deque handed to ``self.get_weight_updates`` on every iteration.
            It starts out as ``deque([None])``, where ``None`` is a sentinel for
            the initial state.
        """

        # The leading None is a sentinel marking the initial state.
        updates = deque([None])
        bar = ProgressBar()
        for step in range(batch_size):
            bar.refresh(step / batch_size)
            picked = choice(train)  # random draw from the training data

            # Forward pass, backward pass, then record the resulting updates.
            self.prepare_network()
            self.propagate_input(picked.features)
            self.propagate_error(picked.label)
            self.get_weight_updates(picked, learning_rate, updates)

        bar.refresh(1.0)
        bar.clear()

        message = 'Computed weight updates for network on batch_size of {}'.format(
            batch_size)
        print(message)

        return updates