示例#1
0
    def __init__(self, model, trn_data, trn_target, val_data, val_target):
        """
        Prepare the shared state needed to train (and optionally validate) a model.

        The data arrays are moved into theano shared variables and the symbolic
        inputs they correspond to are recorded. The actual update and validation
        functions are left as None here; a subclass is expected to provide them.

        :param model: the model to be trained
        :param trn_data: training inputs and (possibly) training targets
        :param trn_target: theano variable representing the training target,
            or None if only model.input is fed
        :param val_data: validation inputs and (possibly) validation targets,
            or None to disable validation
        :param val_target: theano variable representing the validation target,
            or None if only model.input is fed
        """

        # prepare training data
        self.n_trn_data = self._get_n_data(trn_data)
        self.trn_data = [
            theano.shared(x.astype(dtype), borrow=True) for x in trn_data
        ]

        # symbolic inputs that the training data will be substituted for
        self.trn_inputs = [model.input] if trn_target is None else [
            model.input, trn_target
        ]
        self.make_update = None  # to be implemented by a subclass

        # if model uses batch norm, compile a theano function for setting up stats
        if getattr(model, 'batch_norm', False):
            # substitution pairs replacing bn.m with bn.bm and bn.v with bn.bv;
            # only stored here, not used in this method
            # NOTE(review): presumably m/v are the batch statistics and bm/bv
            # their stored copies -- confirm against the batch norm layer
            self.batch_norm_givens = [(bn.m, bn.bm) for bn in model.bns
                                      ] + [(bn.v, bn.bv) for bn in model.bns]
            # zip() is a one-shot iterator on Python 3, while theano documents
            # givens as a list, tuple or dict of pairs; materialise it so the
            # pairs survive being traversed more than once
            self.set_batch_norm_stats = theano.function(
                inputs=[],
                givens=list(zip(self.trn_inputs, self.trn_data)),
                updates=[(bn.bm, bn.m)
                         for bn in model.bns] + [(bn.bv, bn.v)
                                                 for bn in model.bns])
        else:
            self.batch_norm_givens = []
            self.set_batch_norm_stats = None

        # if validation data is given, then set up validation too
        self.do_validation = val_data is not None

        # the validation attributes below exist only when do_validation is True
        if self.do_validation:

            # prepare validation data
            self.n_val_data = self._get_n_data(val_data)
            self.val_data = [
                theano.shared(x.astype(dtype), borrow=True) for x in val_data
            ]

            # symbolic inputs that the validation data will be substituted for
            self.val_inputs = [model.input] if val_target is None else [
                model.input, val_target
            ]
            self.validate = None  # to be implemented by a subclass

            # create checkpointer to store best model
            self.checkpointer = ModelCheckpointer(model)
            self.best_val_loss = float('inf')

        # initialize some variables
        self.trn_loss = float('inf')
        # fixed seed (42) so the index stream is reproducible across runs
        self.idx_stream = ds.IndexSubSampler(self.n_trn_data,
                                             rng=np.random.RandomState(42))
示例#2
0
    def __init__(self,
                 model,
                 trn_data,
                 trn_loss,
                 trn_target=None,
                 val_data=None,
                 val_loss=None,
                 val_target=None,
                 step=None):
        """
        Constructs and configures the trainer.

        Compiles the theano function for a single training update and, if
        validation data is given, the theano function evaluating the
        validation loss.

        :param model: the model to be trained
        :param trn_data: train inputs and (possibly) train targets
        :param trn_loss: theano variable representing the train loss to minimize
        :param trn_target: theano variable representing the train target
        :param val_data: validation inputs and (possibly) validation targets
        :param val_loss: theano variable representing the validation loss
        :param val_target: theano variable representing the validation target
        :param step: step size strategy object; if None, a new ss.Adam() is
            created for this trainer
        :return: None
        """

        # parse input
        # Build the default here rather than in the signature: a default
        # expression is evaluated once at definition time, so every trainer
        # constructed without an explicit step would share the same object.
        if step is None:
            step = ss.Adam()

        # TODO: it would be good to type check the other inputs too
        assert isinstance(
            step, ss.StepStrategy), 'Step must be a step strategy object.'

        # prepare train data: all arrays must agree on their first dimension
        trn_sizes = {x.shape[0] for x in trn_data}
        assert len(trn_sizes) == 1, 'Number of train data is not consistent.'
        self.n_trn_data = next(iter(trn_sizes))
        trn_data = [
            theano.shared(x.astype(dtype), borrow=True) for x in trn_data
        ]

        # TODO: privatise this
        # compile theano function for a single training update: given a vector
        # of row indices, feed those rows of the training data to the model,
        # return the train loss and apply the step strategy's updates
        grads = tt.grad(trn_loss, model.parms)
        idx = tt.ivector('idx')
        trn_inputs = [model.input
                      ] if trn_target is None else [model.input, trn_target]
        self.make_update = theano.function(
            inputs=[idx],
            outputs=trn_loss,
            givens=list(zip(trn_inputs, [x[idx] for x in trn_data])),
            updates=step.updates(model.parms, grads))

        # if model uses batch norm, compile a theano function for setting up stats
        if getattr(model, 'batch_norm', False):
            # substitution pairs used at validation time: bn.m is replaced by
            # bn.bm and bn.v by bn.bv
            # NOTE(review): presumably m/v are the batch statistics and bm/bv
            # their stored copies -- confirm against the batch norm layer
            batch_norm_givens = [(bn.m, bn.bm)
                                 for bn in model.bns] + [(bn.v, bn.bv)
                                                         for bn in model.bns]
            # evaluated on the whole training set: stores bn.m into bn.bm and
            # bn.v into bn.bv
            self.set_batch_norm_stats = theano.function(
                inputs=[],
                givens=list(zip(trn_inputs, trn_data)),
                updates=[(bn.bm, bn.m)
                         for bn in model.bns] + [(bn.bv, bn.v)
                                                 for bn in model.bns])
        else:
            self.set_batch_norm_stats = None
            batch_norm_givens = []

        # if validation data is given, then set up validation too
        self.do_validation = val_data is not None

        # the validation attributes below exist only when do_validation is True
        if self.do_validation:

            # prepare validation data: same consistency check as for training
            val_sizes = {x.shape[0] for x in val_data}
            assert len(
                val_sizes) == 1, 'Number of validation data is not consistent.'
            self.n_val_data = next(iter(val_sizes))
            val_data = [
                theano.shared(x.astype(dtype), borrow=True) for x in val_data
            ]

            # compile theano function for validation: takes no arguments and
            # evaluates val_loss on the whole validation set
            val_inputs = [model.input] if val_target is None else [
                model.input, val_target
            ]
            self.validate = theano.function(
                inputs=[],
                outputs=val_loss,
                givens=list(zip(val_inputs, val_data)) + batch_norm_givens)

            # create checkpointer to store best model
            self.checkpointer = ModelCheckpointer(model)
            self.best_val_loss = float('inf')

        # initialize some variables
        self.trn_loss = float('inf')
        self.idx_stream = ds.IndexSubSampler(self.n_trn_data)