# Imports reconstructed from usage below.  Project-local names (config,
# logger, COLLECTIONS, ANSI, DATA_FOLD_NAMES, DataThread, TrainingError,
# sfloatX, ndim_tensor, none_to_dict, none_to_list, sanitise_dict_for_mongo,
# two_level_dict_to_series, two_level_series_to_dict) are assumed to be
# imported from elsewhere in the neuralnilm package.
import os
import shutil
from copy import copy
from functools import partial
from time import time

import numpy as np
import pandas as pd
import theano
import matplotlib.pyplot as plt
from pymongo import MongoClient
from lasagne.layers import get_all_params
from lasagne.objectives import squared_error, aggregate
from lasagne.updates import nesterov_momentum


class Trainer(object):
    def __init__(self, net, data_pipeline, experiment_id,
                 mongo_host=None,
                 mongo_db='neuralnilm',
                 loss_func=squared_error,
                 loss_aggregation_mode='mean',
                 updates_func=nesterov_momentum,
                 updates_func_kwards=None,
                 learning_rates=None,
                 callbacks=None,
                 repeat_callbacks=None,
                 epoch_callbacks=None,
                 metrics=None,
                 num_seqs_to_plot=8):
        """
        Parameters
        ----------
        experiment_id : list of strings
            Joined with underscores to form the experiment ID.
            Also defines the output path.
        mongo_host : Address of the MongoDB server.  See
            http://docs.mongodb.org/manual/reference/connection-string/
        learning_rates : dict mapping iteration (int) to learning rate (float)
            Each rate takes effect when training reaches its iteration.
            Defaults to {0: 1E-2}.
        callbacks : list of 2-tuples (<iteration>, <function>)
            Function must accept a single argument: this Trainer object.
        repeat_callbacks : list of 2-tuples (<iteration>, <function>)
            Function must accept a single argument: this Trainer object.
            For example, to run validation every 100 iterations, set
            `repeat_callbacks=[(100, Trainer.validate)]`.
        epoch_callbacks : list of functions
            Functions are called at the end of each training epoch.
        metrics : neuralnilm.Metrics object
            Run during `Trainer.validate()`.
        """
        # Database
        if mongo_host is None:
            mongo_host = config.get("MongoDB", "address")
        mongo_client = MongoClient(mongo_host)
        self.db = mongo_client[mongo_db]

        # Training and validation state
        self.requested_learning_rates = (
            {0: 1E-2} if learning_rates is None else learning_rates)
        self.experiment_id = "_".join(experiment_id)
        self._train_func = None
        self.metrics = metrics
        self.net = net
        self.data_pipeline = data_pipeline
        self.min_train_cost = float("inf")
        self.num_seqs_to_plot = num_seqs_to_plot

        # Check if this experiment already exists in the database
        delete_or_quit = None
        if self.db.experiments.find_one({'_id': self.experiment_id}):
            delete_or_quit = input(
                "Database already has an experiment with _id == {}."
                " Should the old experiment be deleted (d)"
                " (both from the database and from disk)? Or quit (q)?"
                " Or append (a) a '_try<i>' suffix to _id? [A/q/d] "
                .format(self.experiment_id)).lower()
            if delete_or_quit == 'd':
                logger.info("Deleting documents for old experiment.")
                self.db.experiments.delete_one({'_id': self.experiment_id})
                for collection in COLLECTIONS:
                    self.db[collection].delete_many(
                        {'experiment_id': self.experiment_id})
            elif delete_or_quit == 'q':
                raise KeyboardInterrupt()
            else:
                self.modification_since_last_try = input(
                    "Enter a short description of what has changed since"
                    " the last try: ")
                try_i = 2
                while True:
                    candidate_id = self.experiment_id + '_try' + str(try_i)
                    if self.db.experiments.find_one({'_id': candidate_id}):
                        try_i += 1
                    else:
                        self.experiment_id = candidate_id
                        logger.info("experiment_id set to {}"
                                    .format(self.experiment_id))
                        break

        # Output path
        path_list = [config.get('Paths', 'output')]
        path_list += self.experiment_id.split('_')
        self.output_path = os.path.join(*path_list)
        try:
            os.makedirs(self.output_path)
        except OSError as os_error:
            if os_error.errno == 17:  # file exists
                logger.info("Directory exists = '{}'".format(self.output_path))
                if delete_or_quit == 'd':
                    logger.info("  Deleting directory.")
                    shutil.rmtree(self.output_path)
                    os.makedirs(self.output_path)
                else:
                    logger.info("  Re-using directory.")
            else:
                raise

        # Loss and updates
        def aggregated_loss_func(prediction, target, weights=None):
            loss = loss_func(prediction, target)
            return aggregate(loss, mode=loss_aggregation_mode,
                             weights=weights)
        self.loss_func_name = loss_func.__name__
        self.loss_func = aggregated_loss_func
        self.updates_func_name = updates_func.__name__
        self.updates_func_kwards = none_to_dict(updates_func_kwards)
        self.updates_func = partial(updates_func, **self.updates_func_kwards)
        self.loss_aggregation_mode = loss_aggregation_mode

        # Learning rate
        # Set _learning_rate to -1 so that, when we set self.learning_rate
        # during the training loop, the change will be logged correctly.
        self._learning_rate = theano.shared(sfloatX(-1), name='learning_rate')

        # Callbacks
        def callbacks_dataframe(lst):
            return pd.DataFrame(lst, columns=['iteration', 'function'])
        self.callbacks = callbacks_dataframe(callbacks)
        self.repeat_callbacks = callbacks_dataframe(repeat_callbacks)
        self.epoch_callbacks = none_to_list(epoch_callbacks)

    def submit_report(self, additional_report_contents=None):
        """Submit report to database.

        Parameters
        ----------
        additional_report_contents : list of tuples of (list of keys, dict)
            e.g.
            >>> contents = [(['data'], {'activations': LOADER_CONFIG})]
            >>> trainer.submit_report(additional_report_contents=contents)
        """
        report = self.report()
        if additional_report_contents is not None:
            for (keys, update) in additional_report_contents:
                dict_to_update = report
                for key in keys:
                    dict_to_update = dict_to_update[key]
                dict_to_update.update(update)
        report = sanitise_dict_for_mongo(report)
        self.db.experiments.insert_one(report)
        return report

    @property
    def learning_rate(self):
        return self._learning_rate.get_value().flatten()[0]

    @learning_rate.setter
    def learning_rate(self, rate):
        rate = sfloatX(rate)
        if rate == self.learning_rate:
            logger.info(
                "Iteration {:d}: Attempted to change learning rate to {:.1E}"
                " but that is already the value!"
                .format(self.net.train_iterations, rate))
        else:
            logger.info(
                "Iteration {:d}: Change learning rate to {:.1E}"
                .format(self.net.train_iterations, rate))
            self.db.experiments.find_one_and_update(
                filter={'_id': self.experiment_id},
                update={'$set': {
                    'trainer.actual_learning_rates.{:d}'
                    .format(self.net.train_iterations): float(rate)}},
                upsert=True)
            self._learning_rate.set_value(rate)

    def _start_data_thread(self):
        self.data_thread = DataThread(self.data_pipeline)
        self.data_thread.start()

    def fit(self, num_iterations=None):
        self._start_data_thread()
        run_menu = False
        try:
            self._training_loop(num_iterations)
        except KeyboardInterrupt:
            logger.info("Keyboard interrupt at iteration {}."
                        .format(self.net.train_iterations))
            run_menu = True
        finally:
            self.data_thread.stop()
        if run_menu:
            self._menu(num_iterations)

    def _training_loop(self, num_iterations=None):
        logger.info("Starting training for {} iterations."
                    .format(num_iterations))
        self.db.experiments.find_one_and_update(
            filter={'_id': self.experiment_id},
            update={'$set': {
                'trainer.requested_train_iterations': num_iterations}},
            upsert=True)
        print("   Update # |  Train cost  | Secs per update | Source ID")
        print("------------|--------------|-----------------|-----------")
        while True:
            try:
                self._single_train_iteration()
            except TrainingError:
                break
            except StopIteration:
                logger.info("Iteration {:d}: Finished training epoch"
                            .format(self.net.train_iterations))
                for callback in self.epoch_callbacks:
                    callback(self)
                continue
            if self.net.train_iterations == num_iterations:
                break
            else:
                self.net.train_iterations += 1
        logger.info("Stopped training. Completed {} iterations."
                    .format(self.net.train_iterations))

    def _single_train_iteration(self):
        # Learning rate changes
        try:
            self.learning_rate = self.requested_learning_rates[
                self.net.train_iterations]
        except KeyError:
            pass

        # Training
        time0 = time()
        batch = self.data_thread.get_batch()
        time1 = time()
        if batch is None:
            raise StopIteration()
        if batch.weights is None:
            batch.weights = np.ones(batch.target.shape, dtype=np.float32)
        time2 = time()
        train_cost = self._get_train_func()(
            batch.input, batch.target, batch.weights)
        train_cost = float(train_cost.flatten()[0])
        time3 = time()

        # Save training costs
        score = {
            'experiment_id': self.experiment_id,
            'iteration': self.net.train_iterations,
            'loss': train_cost,
            'source_id': batch.metadata['source_id']
        }
        self.db.train_scores.insert_one(score)
        time4 = time()
        duration = time4 - time0

        # Print training costs
        is_best = train_cost <= self.min_train_cost
        if is_best:
            self.min_train_cost = train_cost
        print(" {:>10d} | {}{:>10.6f}{} | {:>10.6f} | {:>3d}".format(
            self.net.train_iterations,
            ANSI.BLUE if is_best else "",
            train_cost,
            ANSI.ENDC if is_best else "",
            duration,
            batch.metadata['source_id']))

        # Handle NaN costs
        if np.isnan(train_cost):
            msg = "training cost is NaN at iteration {}!".format(
                self.net.train_iterations)
            logger.error(msg)
            raise TrainingError(msg)

        # Callbacks
        repeat_callbacks = self.repeat_callbacks[
            (self.net.train_iterations %
             self.repeat_callbacks['iteration']) == 0]
        callbacks = self.callbacks[
            self.callbacks['iteration'] == self.net.train_iterations]
        if (len(repeat_callbacks) + len(callbacks)) > 0:
            # Stop the data thread, otherwise we get intermittent issues with
            # the batch generator complaining that it's already running.
            self.data_thread.stop()
            self._run_callbacks(repeat_callbacks)
            self._run_callbacks(callbacks)
            self._start_data_thread()
        time5 = time()
        print("get_batch={:.3f}; np.ones={:.3f}; train={:.3f}; db={:.3f};"
              " rest={:.3f}".format(
                  time1 - time0, time2 - time1, time3 - time2,
                  time4 - time3, time5 - time4))

    def _run_callbacks(self, df):
        for callback in df['function']:
            callback(self)

    def validate(self):
        logger.info("Iteration {}: Running validation..."
                    .format(self.net.train_iterations))
        sources = self.data_thread.data_pipeline.sources
        output_func = self.net.deterministic_output_func
        for source_id, source in enumerate(sources):
            for fold in DATA_FOLD_NAMES:
                scores_accumulator = None
                n = 0
                batch = self.data_thread.data_pipeline.get_batch(
                    fold=fold, source_id=source_id, reset_iterator=True,
                    validation=True)
                while True:
                    output = output_func(batch.after_processing.input)
                    scores_for_batch = self.metrics.compute_metrics(
                        output, batch.after_processing.target)
                    scores_for_batch = two_level_dict_to_series(
                        scores_for_batch)
                    if scores_accumulator is None:
                        scores_accumulator = scores_for_batch
                    else:
                        scores_accumulator += scores_for_batch
                    n += 1
                    batch = self.data_thread.data_pipeline.get_batch(
                        fold=fold, source_id=source_id, validation=True)
                    if batch is None:  # end of validation data
                        break
                scores = scores_accumulator / n
                scores = scores.astype(float)
                self.db.validation_scores.insert_one({
                    'experiment_id': self.experiment_id,
                    'iteration': self.net.train_iterations,
                    'source_id': source_id,
                    'fold': fold,
                    'scores': two_level_series_to_dict(scores)
                })
        logger.info("Iteration {}: Finished validation."
                    .format(self.net.train_iterations))

    def _get_train_func(self):
        if self._train_func is None:
            self._train_func = self._compile_train_func()
        return self._train_func

    def _compile_train_func(self):
        logger.info("Compiling train cost function...")
        network_input = self.net.symbolic_input()
        network_output = self.net.symbolic_output(deterministic=False)
        target_var = ndim_tensor(name='target', ndim=network_output.ndim)
        mask_var = ndim_tensor(name='mask', ndim=network_output.ndim)
        loss = self.loss_func(network_output, target_var, mask_var)
        all_params = get_all_params(self.net.layers[-1], trainable=True)
        updates = self.updates_func(
            loss, all_params, learning_rate=self._learning_rate)
        train_func = theano.function(
            inputs=[network_input, target_var, mask_var],
            outputs=loss,
            updates=updates,
            on_unused_input='warn',
            allow_input_downcast=True)
        logger.info("Done compiling cost function.")
        return train_func

    def _menu(self, epochs):
        # Print menu
        print("")
        print("------------------ OPTIONS ------------------")
        print("d: Enter debugger.")
        print("s: Save plots and params.")
        print("q: Quit this experiment.")
        print("e: Change number of epochs to train this net (currently {})."
              .format(epochs))
        print("c: Continue training.")
        print("")

        # Get input
        selection_str = input("Please enter one or more letters: ")
        selection_str = selection_str.lower()

        # Handle input
        for selection in selection_str:
            if selection == 'd':
                import ipdb
                ipdb.set_trace()
            elif selection == 's':
                self.net.save()
            elif selection == 'q':
                sure = input("Are you sure you want to quit [Y/n]? ")
                if sure.lower() != 'n':
                    raise KeyboardInterrupt()
            elif selection == 'e':
                new_epochs = input("New number of epochs (or 'None'): ")
                if new_epochs == 'None':
                    epochs = None
                else:
                    try:
                        epochs = int(new_epochs)
                    except ValueError:
                        print("'{}' is not an integer!".format(new_epochs))
            elif selection == 'c':
                break
            else:
                print("Selection '{}' not recognised!".format(selection))
                break
        print("Continuing training for {} epochs...".format(epochs))
        self.fit(epochs)

    def report(self):
        report = {'trainer': copy(self.__dict__)}
        for attr in [
                'data_pipeline', 'loss_func', 'net', 'repeat_callbacks',
                'callbacks', 'epoch_callbacks', 'db', 'experiment_id',
                'metrics', '_learning_rate', '_train_func', 'updates_func',
                'min_train_cost']:
            report['trainer'].pop(attr, None)
        report['trainer']['metrics'] = self.metrics.report()
        report['data'] = self.data_pipeline.report()
        report['net'] = self.net.report()
        report['_id'] = self.experiment_id
        return report

    def save_params(self):
        logger.info("Iteration {}: Saving params."
                    .format(self.net.train_iterations))
        filename = os.path.join(self.output_path, 'net_params.h5')
        self.net.save_params(filename=filename)
        logger.info("Done saving params.")

    def plot_estimates(self, linewidth=0.5):
        logger.info("Iteration {}: Plotting estimates."
                    .format(self.net.train_iterations))
        sources = self.data_thread.data_pipeline.sources
        output_func = self.net.deterministic_output_func
        for source_id, source in enumerate(sources):
            for fold in DATA_FOLD_NAMES:
                batch = self.data_thread.data_pipeline.get_batch(
                    fold=fold, source_id=source_id, reset_iterator=True,
                    validation=True)
                output = output_func(batch.input)
                num_seq_per_batch, output_seq_length, _ = output.shape
                num_seqs = min(self.num_seqs_to_plot, num_seq_per_batch)
                for seq_i in range(num_seqs):
                    fig, axes = plt.subplots(3)

                    # Plot network output
                    ax = axes[0]
                    ax.set_title('Network output')
                    ax.plot(output[seq_i], linewidth=linewidth)
                    ax.set_xlim([0, output_seq_length])

                    # Plot target
                    ax = axes[1]
                    ax.set_title('Target')
                    ax.plot(batch.target[seq_i], linewidth=linewidth)
                    ax.set_xlim([0, output_seq_length])

                    # Plot network input
                    ax = axes[2]
                    ax.set_title('Input')
                    ax.plot(batch.input[seq_i], linewidth=linewidth)
                    ax.set_xlim([0, batch.input.shape[1]])

                    # Formatting
                    for ax in axes:
                        ax.grid(True)

                    # Save
                    filename = os.path.join(
                        self.output_path,
                        "{:07d}_{}_{}_seq{}.png".format(
                            self.net.train_iterations, fold,
                            source.__class__.__name__, seq_i))
                    fig.tight_layout()
                    plt.savefig(filename, bbox_inches='tight', dpi=300)
                    plt.close()
        logger.info("Done plotting.")
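
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module).
# `build_net()` and `build_pipeline()` are hypothetical stand-ins for
# whatever constructs your Net and DataPipeline objects, and the
# experiment_id and learning-rate schedule are made-up examples.  The
# repeat_callbacks entry follows the docstring example in __init__.
#
#     trainer = Trainer(
#         net=build_net(),
#         data_pipeline=build_pipeline(),
#         experiment_id=['e100', 'fridge', 'ae'],
#         learning_rates={0: 1e-2, 50000: 1e-3},
#         repeat_callbacks=[
#             (100, Trainer.validate),
#             (1000, Trainer.save_params),
#             (1000, Trainer.plot_estimates),
#         ],
#     )
#     trainer.submit_report()
#     trainer.fit(num_iterations=100000)
# ---------------------------------------------------------------------------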
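
# Example callback (illustrative only, not part of the original module).
# Any callable that accepts this Trainer as its single argument can be used
# in `callbacks`, `repeat_callbacks` or `epoch_callbacks`; this one simply
# logs the best training cost seen so far, using only attributes defined
# above (`net.train_iterations` and `min_train_cost`).
def log_best_cost(trainer):
    logger.info("Iteration {}: best train cost so far = {:.6f}".format(
        trainer.net.train_iterations, trainer.min_train_cost))

# e.g. Trainer(..., epoch_callbacks=[log_best_cost])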