Example #1
0
class Dump(SimpleExtension):
    """Dumps the state of the main loop.

    Makes a `SAVED_TO` record in the log with the dumping destination
    in the case of success and ``None`` in the case of failure.

    Parameters
    ----------
    state_path : str
        The folder to dump the state to. Will be created if it does not
        exist.

    Notes
    -----
    Requires the model to be a Brick or a list of Bricks.

    """
    def __init__(self, state_path, **kwargs):
        # Default to dumping after training; an explicitly passed
        # ``after_training`` argument takes precedence.
        kwargs.setdefault("after_training", True)
        super(Dump, self).__init__(**kwargs)
        self.manager = MainLoopDumpManager(state_path)

    def do(self, callback_name, *args, **kwargs):
        """Dump the main loop into ``self.manager.folder``.

        The `SAVED_TO` record is written *before* the dump is performed,
        presumably so that the dumped log already contains it (TODO
        confirm). If dumping fails, the record is reset to ``None`` and
        the original exception is re-raised.

        """
        try:
            self.main_loop.log.current_row[SAVED_TO] = (
                self.manager.folder)
            self.manager.dump(self.main_loop)
        except Exception:
            self.main_loop.log.current_row[SAVED_TO] = None
            raise
Example #2
0
class LoadFromDump(TrainingExtension):
    """Restore a previously dumped state into the main loop.

    After a successful load, the dump folder is recorded under
    `LOADED_FROM` in the current row of the log.

    Parameters
    ----------
    state_path : str
        Path of the folder that holds the dump.

    Notes
    -----
    Requires the model to be a Brick or a list of Bricks.

    """
    def __init__(self, state_path, **kwargs):
        super(LoadFromDump, self).__init__(**kwargs)
        self.manager = MainLoopDumpManager(state_path)

    def before_training(self):
        """Load the dump into the main loop if the dump folder exists.

        A missing folder is not an error: it is only logged. Any exception
        raised while loading is re-raised through ``reraise_as`` with the
        message "Failed to load the state".

        """
        manager = self.manager
        if os.path.exists(manager.folder):
            logger.info("Loading the state from {} into the main loop"
                        .format(manager.folder))
            try:
                manager.load_to(self.main_loop)
                self.main_loop.log.current_row[LOADED_FROM] = manager.folder
            except Exception:
                reraise_as("Failed to load the state")
        else:
            logger.info("No dump found")
Example #3
0
def test_main_loop_dump_manager():
    """Check that dumping and re-loading a main loop round-trips its state.

    Both temporary folders created by this test are removed when it
    finishes, whether it passes or fails.

    """
    import shutil

    def assert_equal(main_loop1, main_loop2, check_log=True):
        """Check if two main loop objects are equal.

        Compares the first parameter of the first linear transformation of
        the first top brick, optionally the log contents, and the next
        ``"numbers"`` batch drawn from each epoch iterator.

        Notes
        -----
        Corrupts the iteration state! (One batch is consumed from each
        epoch iterator.)

        """
        W1 = (main_loop1.model.get_top_bricks()[0].linear_transformations[0]
                              .params[0].get_value())
        W2 = (main_loop2.model.get_top_bricks()[0].linear_transformations[0]
                              .params[0].get_value())
        assert numpy.all(W1 == W2)
        if check_log:
            assert sorted(main_loop1.log) == sorted(main_loop2.log)
        assert numpy.all(
            next(main_loop1.epoch_iterator)["numbers"] ==
            next(main_loop2.epoch_iterator)["numbers"])

    folder = tempfile.mkdtemp()
    folder2 = tempfile.mkdtemp()
    try:
        main_loop1 = sqrt_example(folder, 17)
        assert main_loop1.log.status.epochs_done == 3
        assert main_loop1.log.status.iterations_done == 17

        # Test loading from the folder where `main_loop` is saved
        main_loop2 = sqrt_example(folder2, 1)
        manager = MainLoopDumpManager(folder)
        manager.load_to(main_loop2)
        assert_equal(main_loop1, main_loop2)

        # Reload because `main_loop2` is corrupted by `assert_equal`
        main_loop2 = sqrt_example(folder2, 1)
        manager.load_to(main_loop2)
        # Continue until 33 iterations are done
        main_loop2.find_extension("FinishAfter").set_conditions(
            after_n_batches=33)
        main_loop2.run()
        assert main_loop2.log.status.iterations_done == 33

        # Compare with a main loop after continuous 33 iterations
        main_loop3 = sqrt_example(folder, 33)
        assert main_loop3.log.status.iterations_done == 33
        assert_equal(main_loop2, main_loop3, check_log=False)
    finally:
        # mkdtemp() leaves cleanup to the caller; don't leak the folders.
        shutil.rmtree(folder, ignore_errors=True)
        shutil.rmtree(folder2, ignore_errors=True)
Example #4
0
def train_model(cost, train_stream, valid_stream, valid_freq, valid_rare,
                load_location=None, save_location=None):
    """Train a model on ``cost`` with plain SGD, optionally loading/saving.

    Side effect: ``cost`` is renamed to ``'nll'`` in place.

    Parameters
    ----------
    cost : symbolic scalar
        Training objective; presumably a natural-log NLL (TODO confirm).
    train_stream : data stream
        Stream the main loop trains on.
    valid_stream, valid_freq, valid_rare : data streams
        Validation streams (all / frequent-word / rare-word subsets),
        monitored with prefixes ``valid_all``, ``valid_frequent`` and
        ``valid_rare``.
    load_location : str, optional
        If given, parameter values are loaded from here before training.
    save_location : str, optional
        If given, the main loop is dumped here after training.

    """
    cost.name = 'nll'
    # 2 ** (cost / ln 2) == exp(cost): perplexity of a natural-log NLL.
    perplexity = 2 ** (cost / tensor.log(2))
    perplexity.name = 'ppl'

    model = Model(cost)
    if load_location is not None:
        logger.info('Loading parameters...')
        saved_values = load_parameter_values(load_location)
        model.set_param_values(saved_values)

    graph = ComputationGraph(cost)
    sgd = GradientDescent(cost=cost, step_rule=Scale(learning_rate=0.01),
                          params=graph.parameters)

    extensions = [
        DataStreamMonitoring([cost, perplexity], valid_stream,
                             prefix='valid_all', every_n_batches=5000),
        # Rare words overfit somewhere between 3000 and 4000 iterations,
        # so that subset is checked ten times as often as the others.
        DataStreamMonitoring([cost, perplexity], valid_rare,
                             prefix='valid_rare', every_n_batches=500),
        DataStreamMonitoring([cost, perplexity], valid_freq,
                             prefix='valid_frequent',
                             every_n_batches=5000),
        Printing(every_n_batches=500),
    ]
    main_loop = MainLoop(model=model, data_stream=train_stream,
                         algorithm=sgd, extensions=extensions)
    main_loop.run()

    if save_location is None:
        return
    logger.info('Saving the main loop...')
    MainLoopDumpManager(save_location).dump(main_loop)
    logger.info('Saved')
def train_model(cost, error_rate, train_stream,
                load_location=None, save_location=None):
    """Train ``cost`` with momentum SGD, optionally loading/saving state.

    Side effect: ``cost`` and ``error_rate`` are renamed in place to
    ``"Cross_entropy"`` and ``'Error_rate'``.

    Parameters
    ----------
    cost : symbolic scalar
        Training objective.
    error_rate : symbolic scalar
        Classification error; monitored together with ``cost`` on the
        training stream after every epoch.
    train_stream : data stream
        Stream used both for training and for the per-epoch monitoring.
    load_location : str, optional
        If given, parameter values are loaded from here before training.
    save_location : str, optional
        If given, the main loop is dumped here after training.

    """
    cost.name = "Cross_entropy"
    error_rate.name = 'Error_rate'

    # Define the model
    model = Model(cost)

    # Load the parameters from a dumped model
    if load_location is not None:
        logger.info('Loading parameters...')
        model.set_param_values(load_parameter_values(load_location))

    cg = ComputationGraph(cost)
    step_rule = Momentum(learning_rate=0.1, momentum=0.9)
    algorithm = GradientDescent(cost=cost, step_rule=step_rule,
                                params=cg.parameters)
    main_loop = MainLoop(
        model=model,
        data_stream=train_stream,
        algorithm=algorithm,
        extensions=[
            # ``error_rate`` was previously named but never monitored,
            # leaving the parameter dead; report it alongside the cost.
            DataStreamMonitoring([cost, error_rate], train_stream,
                                 prefix='train', after_epoch=True),
            Printing(after_epoch=True)
        ]
    )
    main_loop.run()

    # Save the main loop
    if save_location is not None:
        logger.info('Saving the main loop...')
        dump_manager = MainLoopDumpManager(save_location)
        dump_manager.dump(main_loop)
        logger.info('Saved')
Example #6
0
def test_main_loop_dump_manager():
    """Check that dumping and re-loading a main loop round-trips its state.

    Both temporary folders created by this test are removed when it
    finishes, whether it passes or fails.

    """
    import shutil

    def assert_equal(main_loop1, main_loop2, check_log=True):
        """Check if two main loop objects are equal.

        Compares the first parameter of the first linear transformation of
        the first top brick, optionally the log contents, and the next
        ``"numbers"`` batch drawn from each epoch iterator.

        Notes
        -----
        Corrupts the iteration state! (One batch is consumed from each
        epoch iterator.)

        """
        W1 = (main_loop1.model.get_top_bricks()
              [0].linear_transformations[0].params[0].get_value())
        W2 = (main_loop2.model.get_top_bricks()
              [0].linear_transformations[0].params[0].get_value())
        assert numpy.all(W1 == W2)
        if check_log:
            assert sorted(main_loop1.log) == sorted(main_loop2.log)
        assert numpy.all(
            next(main_loop1.epoch_iterator)["numbers"] == next(
                main_loop2.epoch_iterator)["numbers"])

    folder = tempfile.mkdtemp()
    folder2 = tempfile.mkdtemp()
    try:
        main_loop1 = sqrt_example(folder, 17)
        assert main_loop1.log.status.epochs_done == 3
        assert main_loop1.log.status.iterations_done == 17

        # Test loading from the folder where `main_loop` is saved
        main_loop2 = sqrt_example(folder2, 1)
        manager = MainLoopDumpManager(folder)
        manager.load_to(main_loop2)
        assert_equal(main_loop1, main_loop2)

        # Reload because `main_loop2` is corrupted by `assert_equal`
        main_loop2 = sqrt_example(folder2, 1)
        manager.load_to(main_loop2)
        # Continue until 33 iterations are done
        main_loop2.find_extension("FinishAfter").set_conditions(
            after_n_batches=33)
        main_loop2.run()
        assert main_loop2.log.status.iterations_done == 33

        # Compare with a main loop after continuous 33 iterations
        main_loop3 = sqrt_example(folder, 33)
        assert main_loop3.log.status.iterations_done == 33
        assert_equal(main_loop2, main_loop3, check_log=False)
    finally:
        # mkdtemp() leaves cleanup to the caller; don't leak the folders.
        shutil.rmtree(folder, ignore_errors=True)
        shutil.rmtree(folder2, ignore_errors=True)
Example #7
0
 def __init__(self, state_path, **kwargs):
     """Create a dumping extension that writes into ``state_path``.

     ``after_training`` defaults to ``True`` unless the caller supplies
     it; all keyword arguments are forwarded to the parent constructor.

     """
     if "after_training" not in kwargs:
         kwargs["after_training"] = True
     super(Dump, self).__init__(**kwargs)
     self.manager = MainLoopDumpManager(state_path)
Example #8
0
 def __init__(self, state_path, **kwargs):
     """Prepare to load the dump stored in ``state_path``.

     Keyword arguments are forwarded unchanged to the parent constructor.

     """
     super(LoadFromDump, self).__init__(**kwargs)
     dump_manager = MainLoopDumpManager(state_path)
     self.manager = dump_manager
Example #9
0
 def __init__(self, state_path, **kwargs):
     """Create a dumping extension that writes into ``state_path``.

     Unless ``after_training`` is passed explicitly it is set to ``True``;
     every keyword argument is then handed to the parent constructor.

     """
     if "after_training" not in kwargs:
         kwargs["after_training"] = True
     super(Dump, self).__init__(**kwargs)
     self.manager = MainLoopDumpManager(state_path)
Example #10
0
 def __init__(self, state_path, **kwargs):
     """Set up loading from the dump folder ``state_path``.

     Extra keyword arguments go straight to the parent constructor.

     """
     super(LoadFromDump, self).__init__(**kwargs)
     loader = MainLoopDumpManager(state_path)
     self.manager = loader
Example #11
0
def dump(pickle_path, dump_path):
    """Convert a pickled main loop into a `MainLoopDumpManager` dump.

    Parameters
    ----------
    pickle_path : str
        Path of the pickled main loop to read.
    dump_path : str
        Folder to write the dump into.

    Notes
    -----
    Unpickling can execute arbitrary code: only use this on pickles from
    a trusted source.

    """
    # ``change_recursion_limit`` wraps only the load, presumably because
    # unpickling deeply nested graphs needs a higher recursion limit.
    with change_recursion_limit(config.recursion_limit):
        # Close the file deterministically instead of leaking the handle.
        with open(pickle_path, "rb") as source:
            main_loop = cPickle.load(source)
    MainLoopDumpManager(dump_path).dump(main_loop)
Example #12
0
def dump(pickle_path, dump_path, rec_limit=None):
    """Convert a dill-pickled main loop into a `MainLoopDumpManager` dump.

    Parameters
    ----------
    pickle_path : str
        Path of the pickled main loop to read.
    dump_path : str
        Folder to write the dump into.
    rec_limit : int, optional
        If truthy, passed to ``sys.setrecursionlimit`` before loading.
        NOTE: this changes the process-wide limit and it is not restored
        afterwards; it stays in effect for the dump step as well.

    Notes
    -----
    Unpickling can execute arbitrary code: only use this on pickles from
    a trusted source.

    """
    if rec_limit:
        sys.setrecursionlimit(rec_limit)
    # Close the file deterministically instead of leaking the handle.
    with open(pickle_path, "rb") as source:
        main_loop = dill.load(source)
    MainLoopDumpManager(dump_path).dump(main_loop)