Exemplo n.º 1
0
class Dump(SimpleExtension):
    """Dumps the state of the main loop.

    Makes a `SAVED_TO` record in the log with the dumping destination
    in the case of success and ``None`` in the case of failure.

    Parameters
    ----------
    state_path : str
        The folder to dump the state to. Will be created if it does not
        exist.

    Notes
    -----
    Requires the model to be a Brick or a list of Bricks.

    """
    def __init__(self, state_path, **kwargs):
        # Enable the ``after_training`` trigger unless the caller set it
        # explicitly.
        kwargs.setdefault("after_training", True)
        super(Dump, self).__init__(**kwargs)
        self.manager = MainLoopDumpManager(state_path)

    def do(self, callback_name, *args, **kwargs):
        """Dump the main loop and record the destination in the log.

        Parameters
        ----------
        callback_name : str
            Name of the callback that triggered the extension; not used.

        Raises
        ------
        Exception
            Whatever the dump raises is re-raised after the ``SAVED_TO``
            record of the current log row has been reset to ``None``.

        """
        try:
            # The record is written *before* dumping -- presumably so that
            # the dumped log already contains the destination.
            self.main_loop.log.current_row[SAVED_TO] = self.manager.folder
            self.manager.dump(self.main_loop)
        except Exception:
            # Mark the failure in the log, then propagate the error.
            self.main_loop.log.current_row[SAVED_TO] = None
            raise
Exemplo n.º 2
0
class Dump(SimpleExtension):
    """Dumps the state of the main loop.

    Makes a `SAVED_TO` record in the log with the dumping destination
    in the case of success and ``None`` in the case of failure.

    Parameters
    ----------
    state_path : str
        The folder to dump the state to. Will be created if it does not
        exist.

    Notes
    -----
    Requires the model to be a Brick or a list of Bricks.

    """
    def __init__(self, state_path, **kwargs):
        # Enable the ``after_training`` trigger unless the caller set it
        # explicitly.
        kwargs.setdefault("after_training", True)
        super(Dump, self).__init__(**kwargs)
        self.manager = MainLoopDumpManager(state_path)

    def do(self, callback_name, *args, **kwargs):
        """Dump the main loop and record the destination in the log.

        Parameters
        ----------
        callback_name : str
            Name of the callback that triggered the extension; not used.

        Raises
        ------
        Exception
            Whatever the dump raises is re-raised after the ``SAVED_TO``
            record of the current log row has been reset to ``None``.

        """
        try:
            # The record is written *before* dumping -- presumably so that
            # the dumped log already contains the destination.
            self.main_loop.log.current_row[SAVED_TO] = (
                self.manager.folder)
            self.manager.dump(self.main_loop)
        except Exception:
            # Mark the failure in the log, then propagate the error.
            self.main_loop.log.current_row[SAVED_TO] = None
            raise
Exemplo n.º 3
0
def train_model(cost,
                train_stream,
                valid_stream,
                valid_freq,
                valid_rare,
                load_location=None,
                save_location=None):
    """Train with plain SGD while monitoring NLL and perplexity.

    Parameters
    ----------
    cost
        Symbolic training cost (presumably a Theano variable holding a
        natural-log negative log-likelihood). Renamed to ``'nll'``.
    train_stream
        Data stream the main loop trains on.
    valid_stream, valid_rare, valid_freq
        Validation streams monitored under the prefixes ``valid_all``,
        ``valid_rare`` and ``valid_frequent`` respectively.
    load_location : optional
        If given, parameter values are loaded from here before training.
    save_location : optional
        If given, the main loop is dumped here after training finishes.

    """
    cost.name = 'nll'
    # 2 ** (c / ln 2) == e ** c: perplexity of a natural-log NLL.
    perplexity = 2**(cost / tensor.log(2))
    perplexity.name = 'ppl'

    model = Model(cost)
    if load_location is not None:
        logger.info('Loading parameters...')
        saved_values = load_parameter_values(load_location)
        model.set_param_values(saved_values)

    trainable = ComputationGraph(cost).parameters
    sgd = GradientDescent(cost=cost,
                          step_rule=Scale(learning_rate=0.01),
                          params=trainable)

    # (stream, log prefix, period in batches). Overfitting of rare words
    # occurs between 3000 and 4000 iterations, presumably why the rare-word
    # set is checked ten times as often.
    schedule = [(valid_stream, 'valid_all', 5000),
                (valid_rare, 'valid_rare', 500),
                (valid_freq, 'valid_frequent', 5000)]
    extensions = [DataStreamMonitoring([cost, perplexity], stream,
                                       prefix=prefix,
                                       every_n_batches=period)
                  for stream, prefix, period in schedule]
    extensions.append(Printing(every_n_batches=500))

    main_loop = MainLoop(model=model,
                         data_stream=train_stream,
                         algorithm=sgd,
                         extensions=extensions)
    main_loop.run()

    if save_location is None:
        return
    logger.info('Saving the main loop...')
    MainLoopDumpManager(save_location).dump(main_loop)
    logger.info('Saved')
Exemplo n.º 4
0
def train_model(cost,
                error_rate,
                train_stream,
                load_location=None,
                save_location=None):
    """Train with momentum SGD, reporting cost and error rate each epoch.

    Parameters
    ----------
    cost
        Symbolic training cost (presumably a Theano variable). Renamed to
        ``"Cross_entropy"``.
    error_rate
        Symbolic error rate, renamed to ``'Error_rate'`` and monitored on
        ``train_stream`` after every epoch.
    train_stream
        Data stream used for training and for the per-epoch monitoring.
    load_location : optional
        If given, parameter values are loaded from here before training.
    save_location : optional
        If given, the main loop is dumped here after training finishes.

    """
    cost.name = "Cross_entropy"
    error_rate.name = 'Error_rate'

    # Define the model
    model = Model(cost)

    # Load the parameters from a dumped model
    if load_location is not None:
        logger.info('Loading parameters...')
        model.set_param_values(load_parameter_values(load_location))

    cg = ComputationGraph(cost)
    step_rule = Momentum(learning_rate=0.1, momentum=0.9)
    algorithm = GradientDescent(cost=cost,
                                step_rule=step_rule,
                                params=cg.parameters)
    main_loop = MainLoop(
        model=model,
        data_stream=train_stream,
        algorithm=algorithm,
        extensions=[
            # error_rate is reported next to the cost; assumes it depends
            # only on inputs that train_stream provides (like the cost).
            DataStreamMonitoring([cost, error_rate],
                                 train_stream,
                                 prefix='train',
                                 after_epoch=True),
            Printing(after_epoch=True)
        ])
    main_loop.run()

    # Save the main loop
    if save_location is not None:
        logger.info('Saving the main loop...')
        dump_manager = MainLoopDumpManager(save_location)
        dump_manager.dump(main_loop)
        logger.info('Saved')
Exemplo n.º 5
0
def train_model(cost, train_stream, valid_stream, valid_freq, valid_rare,
                load_location=None, save_location=None):
    """Run SGD training while tracking NLL and perplexity on three sets.

    Parameters
    ----------
    cost
        Symbolic training cost (presumably a Theano variable holding a
        natural-log negative log-likelihood). Renamed to ``'nll'``.
    train_stream
        Data stream the main loop trains on.
    valid_stream, valid_freq, valid_rare
        Validation streams logged as ``valid_all``, ``valid_frequent`` and
        ``valid_rare``.
    load_location : optional
        Where to load initial parameter values from, if anywhere.
    save_location : optional
        Where to dump the main loop after training, if anywhere.

    """
    cost.name = 'nll'
    # Equivalent to exp(cost): 2 ** (c / ln 2) == e ** c.
    perplexity = 2 ** (cost / tensor.log(2))
    perplexity.name = 'ppl'

    model = Model(cost)
    if load_location is not None:
        logger.info('Loading parameters...')
        model.set_param_values(load_parameter_values(load_location))

    graph = ComputationGraph(cost)
    step = Scale(learning_rate=0.01)
    optimizer = GradientDescent(cost=cost, step_rule=step,
                                params=graph.parameters)

    def monitor(stream, prefix, period):
        # Fresh list of monitored quantities for every monitor.
        return DataStreamMonitoring([cost, perplexity], stream,
                                    prefix=prefix, every_n_batches=period)

    loop = MainLoop(
        model=model,
        data_stream=train_stream,
        algorithm=optimizer,
        extensions=[
            monitor(valid_stream, 'valid_all', 5000),
            # Overfitting of rare words occurs between 3000 and 4000
            # iterations
            monitor(valid_rare, 'valid_rare', 500),
            monitor(valid_freq, 'valid_frequent', 5000),
            Printing(every_n_batches=500),
        ],
    )
    loop.run()

    if save_location is None:
        return
    logger.info('Saving the main loop...')
    dumper = MainLoopDumpManager(save_location)
    dumper.dump(loop)
    logger.info('Saved')
def train_model(cost, error_rate, train_stream,
                load_location=None, save_location=None):
    """Train with momentum SGD, reporting cost and error rate each epoch.

    Parameters
    ----------
    cost
        Symbolic training cost (presumably a Theano variable). Renamed to
        ``"Cross_entropy"``.
    error_rate
        Symbolic error rate, renamed to ``'Error_rate'`` and monitored on
        ``train_stream`` after every epoch.
    train_stream
        Data stream used for training and for the per-epoch monitoring.
    load_location : optional
        If given, parameter values are loaded from here before training.
    save_location : optional
        If given, the main loop is dumped here after training finishes.

    """
    cost.name = "Cross_entropy"
    error_rate.name = 'Error_rate'

    # Define the model
    model = Model(cost)

    # Load the parameters from a dumped model
    if load_location is not None:
        logger.info('Loading parameters...')
        model.set_param_values(load_parameter_values(load_location))

    cg = ComputationGraph(cost)
    step_rule = Momentum(learning_rate=0.1, momentum=0.9)
    algorithm = GradientDescent(cost=cost, step_rule=step_rule,
                                params=cg.parameters)
    main_loop = MainLoop(
        model=model,
        data_stream=train_stream,
        algorithm=algorithm,
        extensions=[
            # error_rate is reported next to the cost; assumes it depends
            # only on inputs that train_stream provides (like the cost).
            DataStreamMonitoring([cost, error_rate], train_stream,
                                 prefix='train', after_epoch=True),
            Printing(after_epoch=True)
        ]
    )
    main_loop.run()

    # Save the main loop
    if save_location is not None:
        logger.info('Saving the main loop...')
        dump_manager = MainLoopDumpManager(save_location)
        dump_manager.dump(main_loop)
        logger.info('Saved')