def train(self, train, validation, num_epochs=None, learning_rate=0.01, threshold=0.001):
    """Train the FFNN with stochastic gradient descent.

    Each epoch performs len(train) update steps, each on a sample drawn
    uniformly at random from ``train`` (so a given sample may be seen zero
    or multiple times per epoch), then measures RMSE on ``validation``.

    Stopping behavior:
      * If ``num_epochs`` is given (truthy), exactly that many epochs run.
      * Otherwise (dynamic stopping), training runs for at most 600 epochs
        and halts early once the epoch-to-epoch drop in validation error
        has stayed below ``threshold`` for 100 consecutive epochs (this
        also covers the case where validation error starts increasing).
        The weights associated with the best validation error seen are
        saved along the way and finalized before returning.

    Args:
        train: Sequence of samples, each with ``.features`` and ``.label``.
        validation: Sequence of samples used to compute validation RMSE.
        num_epochs: Fixed number of epochs; ``None`` enables dynamic
            stopping. NOTE(review): ``0`` is falsy and therefore also
            triggers dynamic stopping — confirm that is intended.
        learning_rate: Step size passed to ``update_weights``.
        threshold: Minimum per-epoch improvement in validation error that
            resets the dynamic-stopping retry counter.

    Returns:
        The number of epochs actually run (the epoch index at early stop,
        or the full epoch count if the loop ran to completion).
    """
    num_epochs_iter = num_epochs if num_epochs else 600  # 600 set to max epochs
    dynamic_stopping = False if num_epochs else True  # Dynamically halt if num_epochs unspec.
    retries = 0  # Consecutive epochs with improvement below threshold
    err = self.evaluate(validation)  # Baseline validation error before any training
    progress_bar = ProgressBar()
    for epoch in range(num_epochs_iter):
        last_err = err
        for i in range(len(train)):
            progress_bar.refresh(i / len(train))
            sample = choice(train)  # Randomly sample training data
            # Update weights based on the chosen sample
            self.prepare_network()
            self.propagate_input(sample.features)
            self.propagate_error(sample.label)
            self.update_weights(sample, learning_rate, momentum=0.3)
        progress_bar.refresh(1.0)
        progress_bar.clear()
        # Evaluate validation error
        err = self.evaluate(validation)
        print('Epoch {} validation error: {:.4f}'.format(epoch, err))
        if dynamic_stopping:
            if last_err - err < threshold:
                if err <= last_err:  # Still improved, but below threshold
                    self.save_network_weights(err)
                retries += 1
                if retries >= 100:  # Patience exhausted: stop early
                    epochs_ran = epoch
                    break
            else:
                # Meaningful improvement: record weights, reset patience
                self.save_network_weights(err)
                retries = 0
    else:
        epochs_ran = num_epochs_iter  # Loop did not stop early
    if dynamic_stopping:
        self.finalize_network_weights()  # Finalize weights to best validation error
    return epochs_ran
def evaluate(self, samples):
    """Return the network's root-mean-square error over ``samples``.

    Args:
        samples: Non-empty sequence of samples accepted by ``sq_error``.

    Returns:
        sqrt(mean squared error) across all samples.
    """
    bar = ProgressBar()
    total = len(samples)
    sum_sq = 0
    for done, sample in enumerate(samples):
        bar.refresh(done / total)
        sum_sq += self.sq_error(sample)
    bar.refresh(1.0)
    bar.clear()
    return sqrt(sum_sq / total)
def eval_classification(self, samples, verbose=False):
    """Return the classification error rate for ``samples`` as a decimal.

    Args:
        samples: Non-empty sequence of samples accepted by
            ``classification_error``.
        verbose: Forwarded to ``classification_error`` for per-sample
            reporting.

    Returns:
        Fraction of samples classified incorrectly (0.0 .. 1.0).
    """
    bar = ProgressBar()
    total = len(samples)
    misclassified = 0
    for done, sample in enumerate(samples):
        bar.refresh(done / total)
        misclassified += self.classification_error(sample, verbose)
    bar.refresh(1.0)
    bar.clear()
    # Percent of samples classified incorrectly
    return misclassified / total
def train_worker(self, train, learning_rate=0.01, batch_size=64):
    """Compute a batch of weight updates as a worker node (distributed training).

    Runs ``batch_size`` gradient steps, each on a sample drawn uniformly
    at random from ``train``, accumulating the resulting weight deltas
    instead of applying them locally.

    Args:
        train: Sequence of samples, each with ``.features`` and ``.label``.
        learning_rate: Step size passed to ``get_weight_updates``.
        batch_size: Number of random samples to process.

    Returns:
        A ``deque`` of accumulated weight updates; the leading ``None``
        entry is an initialization sentinel consumed by the aggregator.
    """
    updates = deque([None])  # None is a sentinel for init
    bar = ProgressBar()
    for step in range(batch_size):
        bar.refresh(step / batch_size)
        sample = choice(train)  # Randomly sample training data
        # Forward/backward pass, then accumulate (not apply) the deltas
        self.prepare_network()
        self.propagate_input(sample.features)
        self.propagate_error(sample.label)
        self.get_weight_updates(sample, learning_rate, updates)
    bar.refresh(1.0)
    bar.clear()
    print('Computed weight updates for network on batch_size of {}'.format(
        batch_size))
    return updates