Example #1
File: trainers.py  Project: 52nlp/deepy
    def __init__(self, network, config=None, method=None):

        if method:
            logging.info("changing optimization method to '%s'" % method)
            if not config:
                config = TrainerConfig()
            elif isinstance(config, dict):
                config = TrainerConfig(config)
            config.method = method

        super(GeneralNeuralTrainer, self).__init__(network, config)

        logging.info('compiling %s learning function', self.__class__.__name__)

        network_updates = list(network.updates) + list(network.training_updates)
        learning_updates = list(self.learning_updates())
        update_list = network_updates + learning_updates
        logging.info("network updates: %s" % " ".join(map(str, [x[0] for x in network_updates])))
        logging.info("learning updates: %s" % " ".join(map(str, [x[0] for x in learning_updates])))

        self.learning_func = theano.function(
            network.input_variables + network.target_variables,
            map(lambda v: theano.Out(v, borrow=True), self.training_variables),
            updates=update_list, allow_input_downcast=True,
            mode=self.config.get("theano_mode", None))
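All of the examples on this page follow the same compilation pattern: normalize the config, collect the network's update pairs, then compile a Theano function whose outputs are wrapped in theano.Out(..., borrow=True) so output buffers are not copied. Below is a minimal, self-contained sketch of that pattern; it is not taken from deepy, and the toy loss and update rule are assumptions for illustration only.

    import theano
    import theano.tensor as T

    x = T.vector("x")
    w = theano.shared(0.0, name="w")          # a shared parameter, standing in for network weights
    loss = ((w * x) ** 2).sum()               # toy objective (assumption, for illustration)
    update_list = [(w, w - 0.01)]             # analogous to network_updates + learning_updates above
    step = theano.function(
        [x],                                  # inputs, like input_variables + target_variables
        theano.Out(loss, borrow=True),        # borrowed output, as in the trainers above
        updates=update_list,
        allow_input_downcast=True)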
Example #2
    def __init__(self, network, method=None, config=None, annealer=None, validator=None):

        if method:
            logging.info("changing optimization method to '%s'" % method)
            if not config:
                config = TrainerConfig()
            elif isinstance(config, dict):
                config = TrainerConfig(config)
            config.method = method

        super(GeneralNeuralTrainer, self).__init__(network, config, annealer=annealer, validator=validator)

        self._learning_func = None
Example #3
File: trainers.py  Project: JunjieHu/deepy
    def __init__(self, network, config=None, method=None):

        if method:
            logging.info("changing optimization method to '%s'" % method)
            if not config:
                config = TrainerConfig()
            elif isinstance(config, dict):
                config = TrainerConfig(config)
            config.method = method

        super(GeneralNeuralTrainer, self).__init__(network, config)

        self._learning_func = None
Example #4
    def __init__(self, network, config=None, method=None):

        if method:
            logging.info("changing optimization method to '%s'" % method)
            if not config:
                config = TrainerConfig()
            config.method = method

        super(GeneralNeuralTrainer, self).__init__(network, config)

        logging.info('compiling %s learning function', self.__class__.__name__)

        network_updates = list(network.updates) + list(network.training_updates)
        learning_updates = list(self.learning_updates())
        update_list = network_updates + learning_updates
        logging.info("network updates: %s" % " ".join(map(str, [x[0] for x in network_updates])))
        logging.info("learning updates: %s" % " ".join(map(str, [x[0] for x in learning_updates])))

        self.learning_func = theano.function(
            network.input_variables + network.target_variables,
            self.training_variables,
            updates=update_list, allow_input_downcast=True,
            mode=config.get("theano_mode", theano.Mode(linker=THEANO_LINKER)))
Example #5
    def __init__(self, network, config=None, method=None):

        if method:
            logging.info("changing optimization method to '%s'" % method)
            if not config:
                config = TrainerConfig()
            elif isinstance(config, dict):
                config = TrainerConfig(config)
            config.method = method

        super(GeneralNeuralTrainer, self).__init__(network, config)

        logging.info('compiling %s learning function', self.__class__.__name__)

        network_updates = list(network.updates) + list(
            network.training_updates)
        learning_updates = list(self.learning_updates())
        update_list = network_updates + learning_updates
        logging.info("network updates: %s" %
                     " ".join(map(str, [x[0] for x in network_updates])))
        logging.info("learning updates: %s" %
                     " ".join(map(str, [x[0] for x in learning_updates])))

        if False and config.data_transmitter:
            variables = [config.data_transmitter.get_iterator()]
            givens = config.data_transmitter.get_givens()
        else:
            variables = network.input_variables + network.target_variables
            givens = None

        self.learning_func = theano.function(
            variables,
            map(lambda v: theano.Out(v, borrow=True), self.training_variables),
            updates=update_list,
            allow_input_downcast=True,
            mode=self.config.get("theano_mode", None),
            givens=givens)
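A hedged usage sketch: given a compiled deepy network object, the constructor shown in these examples can be invoked roughly as follows. The "momentum" method name and the config dictionary key are assumptions for illustration, not taken from the examples; what the examples do guarantee is that a plain dict is wrapped in TrainerConfig and that the method argument overrides config.method.

    # network = ...  # a compiled deepy network (assumed to exist)
    trainer = GeneralNeuralTrainer(network, config={"learning_rate": 0.01}, method="momentum")
    default_trainer = GeneralNeuralTrainer(network)  # falls back to a fresh TrainerConfig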