class Dump(SimpleExtension):
    """Dump the main loop state to disk.

    On success a `SAVED_TO` record holding the dump destination is
    written to the log; on failure the record is set to ``None`` and
    the exception is re-raised.

    Parameters
    ----------
    state_path : str
        The folder to dump the state to. Will be created if it does
        not exist.

    Notes
    -----
    Requires the model to be a Brick or a list of Bricks.

    """
    def __init__(self, state_path, **kwargs):
        # Dump after training unless the caller chose other triggers.
        if "after_training" not in kwargs:
            kwargs["after_training"] = True
        super(Dump, self).__init__(**kwargs)
        self.manager = MainLoopDumpManager(state_path)

    def do(self, callback_name, *args, **kwargs):
        # Record the destination first; on any failure overwrite it
        # with ``None`` so the log reflects that no dump was produced.
        try:
            self.main_loop.log.current_row[SAVED_TO] = self.manager.folder
            self.manager.dump(self.main_loop)
        except Exception:
            self.main_loop.log.current_row[SAVED_TO] = None
            raise
class LoadFromDump(TrainingExtension):
    """Load a previously saved dump into the main loop.

    A `LOADED_FROM` record with the dump path is written to the log
    when loading succeeds. If the dump folder does not exist the
    extension is a no-op.

    Parameters
    ----------
    state_path : str
        The path to the folder with dump.

    Notes
    -----
    Requires the model to be a Brick or a list of Bricks.

    """
    def __init__(self, state_path, **kwargs):
        super(LoadFromDump, self).__init__(**kwargs)
        self.manager = MainLoopDumpManager(state_path)

    def before_training(self):
        folder = self.manager.folder
        # Missing dump folder means a fresh run: nothing to restore.
        if not os.path.exists(folder):
            logger.info("No dump found")
            return
        logger.info("Loading the state from {} into the main loop"
                    .format(folder))
        try:
            self.manager.load_to(self.main_loop)
            self.main_loop.log.current_row[LOADED_FROM] = folder
        except Exception:
            reraise_as("Failed to load the state")
def test_main_loop_dump_manager():
    def check_equivalent(loop_a, loop_b, compare_logs=True):
        """Assert that two main loops carry the same state.

        Notes
        -----
        Corrupts the iteration state!

        """
        brick_a = loop_a.model.get_top_bricks()[0]
        brick_b = loop_b.model.get_top_bricks()[0]
        weights_a = brick_a.linear_transformations[0].params[0].get_value()
        weights_b = brick_b.linear_transformations[0].params[0].get_value()
        assert numpy.all(weights_a == weights_b)
        if compare_logs:
            assert sorted(list(loop_a.log)) == sorted(list(loop_b.log))
        batch_a = next(loop_a.epoch_iterator)["numbers"]
        batch_b = next(loop_b.epoch_iterator)["numbers"]
        assert numpy.all(batch_a == batch_b)

    folder = tempfile.mkdtemp()
    folder2 = tempfile.mkdtemp()
    main_loop1 = sqrt_example(folder, 17)
    assert main_loop1.log.status.epochs_done == 3
    assert main_loop1.log.status.iterations_done == 17

    # Load the state saved under `folder` into a freshly built loop.
    main_loop2 = sqrt_example(folder2, 1)
    manager = MainLoopDumpManager(folder)
    manager.load_to(main_loop2)
    check_equivalent(main_loop1, main_loop2)

    # `check_equivalent` consumed iterator state, so rebuild and reload.
    main_loop2 = sqrt_example(folder2, 1)
    manager.load_to(main_loop2)

    # Resume training until 33 iterations are done in total.
    main_loop2.find_extension("FinishAfter").set_conditions(
        after_n_batches=33)
    main_loop2.run()
    assert main_loop2.log.status.iterations_done == 33

    # A loop trained for 33 iterations in one go must match the
    # resumed one (logs differ, so they are excluded).
    main_loop3 = sqrt_example(folder, 33)
    assert main_loop3.log.status.iterations_done == 33
    check_equivalent(main_loop2, main_loop3, compare_logs=False)
def train_model(cost, train_stream, valid_stream, valid_freq, valid_rare,
                load_location=None, save_location=None):
    """Train the model, monitoring validation cost and perplexity.

    Parameters
    ----------
    cost : theano variable
        The cost (negative log-likelihood) to minimize.
    train_stream, valid_stream, valid_freq, valid_rare : data streams
        Training stream plus validation streams for all words, rare
        words and frequent words respectively.
    load_location : str, optional
        Folder with dumped parameters to warm-start from.
    save_location : str, optional
        Folder to dump the main loop to after training.

    """
    cost.name = 'nll'
    perplexity = 2 ** (cost / tensor.log(2))
    perplexity.name = 'ppl'

    model = Model(cost)

    # Warm-start from previously dumped parameters when requested.
    if load_location is not None:
        logger.info('Loading parameters...')
        model.set_param_values(load_parameter_values(load_location))

    cg = ComputationGraph(cost)
    algorithm = GradientDescent(cost=cost,
                                step_rule=Scale(learning_rate=0.01),
                                params=cg.parameters)
    monitors = [
        DataStreamMonitoring([cost, perplexity], valid_stream,
                             prefix='valid_all', every_n_batches=5000),
        # Overfitting of rare words occurs between 3000 and 4000
        # iterations, hence the tighter monitoring interval here.
        DataStreamMonitoring([cost, perplexity], valid_rare,
                             prefix='valid_rare', every_n_batches=500),
        DataStreamMonitoring([cost, perplexity], valid_freq,
                             prefix='valid_frequent',
                             every_n_batches=5000),
        Printing(every_n_batches=500)
    ]
    main_loop = MainLoop(model=model, data_stream=train_stream,
                         algorithm=algorithm, extensions=monitors)
    main_loop.run()

    # Persist the trained main loop when requested.
    if save_location is not None:
        logger.info('Saving the main loop...')
        dump_manager = MainLoopDumpManager(save_location)
        dump_manager.dump(main_loop)
        logger.info('Saved')
def train_model(cost, error_rate, train_stream,
                load_location=None, save_location=None):
    """Train the model with momentum SGD, monitoring training cost.

    Parameters
    ----------
    cost : theano variable
        The cross-entropy cost to minimize.
    error_rate : theano variable
        Classification error rate (named for monitoring).
    train_stream : data stream
        The training data stream.
    load_location : str, optional
        Folder with dumped parameters to warm-start from.
    save_location : str, optional
        Folder to dump the main loop to after training.

    """
    cost.name = "Cross_entropy"
    error_rate.name = 'Error_rate'

    model = Model(cost)

    # Warm-start from previously dumped parameters when requested.
    if load_location is not None:
        logger.info('Loading parameters...')
        model.set_param_values(load_parameter_values(load_location))

    cg = ComputationGraph(cost)
    step_rule = Momentum(learning_rate=0.1, momentum=0.9)
    algorithm = GradientDescent(cost=cost, step_rule=step_rule,
                                params=cg.parameters)
    extensions = [
        # DataStreamMonitoring([cost], test_stream, prefix='test',
        #                      after_epoch=False, every_n_epochs=10),
        DataStreamMonitoring([cost], train_stream, prefix='train',
                             after_epoch=True),
        Printing(after_epoch=True)
    ]
    main_loop = MainLoop(model=model, data_stream=train_stream,
                         algorithm=algorithm, extensions=extensions)
    main_loop.run()

    # Persist the trained main loop when requested.
    if save_location is not None:
        logger.info('Saving the main loop...')
        dump_manager = MainLoopDumpManager(save_location)
        dump_manager.dump(main_loop)
        logger.info('Saved')
def test_main_loop_dump_manager():
    def loops_match(first, second, with_log=True):
        """Compare the parameters, logs and iterator state of two loops.

        Notes
        -----
        Corrupts the iteration state!

        """
        get_w = lambda loop: (loop.model.get_top_bricks()[0]
                              .linear_transformations[0]
                              .params[0].get_value())
        assert numpy.all(get_w(first) == get_w(second))
        if with_log:
            assert sorted(list(first.log)) == sorted(list(second.log))
        assert numpy.all(
            next(first.epoch_iterator)["numbers"] ==
            next(second.epoch_iterator)["numbers"])

    dump_dir = tempfile.mkdtemp()
    other_dir = tempfile.mkdtemp()
    main_loop1 = sqrt_example(dump_dir, 17)
    assert main_loop1.log.status.epochs_done == 3
    assert main_loop1.log.status.iterations_done == 17

    # Restore the state dumped in `dump_dir` into a new main loop.
    main_loop2 = sqrt_example(other_dir, 1)
    manager = MainLoopDumpManager(dump_dir)
    manager.load_to(main_loop2)
    loops_match(main_loop1, main_loop2)

    # `loops_match` advanced the iterators, so rebuild and reload.
    main_loop2 = sqrt_example(other_dir, 1)
    manager.load_to(main_loop2)

    # Continue until 33 iterations are done in total.
    main_loop2.find_extension("FinishAfter").set_conditions(
        after_n_batches=33)
    main_loop2.run()
    assert main_loop2.log.status.iterations_done == 33

    # An uninterrupted 33-iteration run must produce the same weights.
    main_loop3 = sqrt_example(dump_dir, 33)
    assert main_loop3.log.status.iterations_done == 33
    loops_match(main_loop2, main_loop3, with_log=False)
def __init__(self, state_path, **kwargs):
    """Create the dump manager; trigger after training by default."""
    # Only supply the default trigger if the caller gave none.
    if "after_training" not in kwargs:
        kwargs["after_training"] = True
    super(Dump, self).__init__(**kwargs)
    self.manager = MainLoopDumpManager(state_path)
def __init__(self, state_path, **kwargs):
    """Remember which folder the dump should be loaded from."""
    self.manager = MainLoopDumpManager(state_path)
    super(LoadFromDump, self).__init__(**kwargs)
def dump(pickle_path, dump_path):
    """Unpickle a main loop and dump its state to a folder.

    Parameters
    ----------
    pickle_path : str
        Path of the pickled main loop to read.
    dump_path : str
        Destination folder for the dump.

    """
    with change_recursion_limit(config.recursion_limit):
        # Use a context manager so the pickle file is closed even if
        # unpickling raises (the original leaked the file handle).
        with open(pickle_path, "rb") as source:
            main_loop = cPickle.load(source)
        MainLoopDumpManager(dump_path).dump(main_loop)
def dump(pickle_path, dump_path, rec_limit=None):
    """Deserialize a main loop with dill and dump its state.

    Parameters
    ----------
    pickle_path : str
        Path of the serialized main loop to read.
    dump_path : str
        Destination folder for the dump.
    rec_limit : int, optional
        If given (and truthy), raise the interpreter recursion limit
        before deserializing deeply nested main loops.

    """
    if rec_limit:
        sys.setrecursionlimit(rec_limit)
    # Close the file even if deserialization raises (the original
    # leaked the file handle).
    with open(pickle_path, "rb") as source:
        main_loop = dill.load(source)
    MainLoopDumpManager(dump_path).dump(main_loop)