def __init__(self, model, trn_data, trn_target, val_data, val_target):
    """
    Sets up shared training/validation state for a trainer.

    Concrete subclasses are expected to assign ``self.make_update`` (and
    ``self.validate`` when validation data is given), which are initialized
    to None here.

    :param model: the model to be trained
    :param trn_data: training inputs and (possibly) training targets
    :param trn_target: theano variable representing the training target
    :param val_data: validation inputs and (possibly) validation targets
    :param val_target: theano variable representing the validation target
    """

    # prepare training data: upload each array to a theano shared variable
    self.n_trn_data = self._get_n_data(trn_data)
    self.trn_data = [
        theano.shared(x.astype(dtype), borrow=True) for x in trn_data
    ]

    # inputs for a single training update; the update itself is compiled
    # by the subclass
    self.trn_inputs = [model.input] if trn_target is None else [
        model.input, trn_target
    ]
    self.make_update = None  # to be implemented by a subclass

    # if model uses batch norm, compile a theano function for setting up stats
    if getattr(model, 'batch_norm', False):
        self.batch_norm_givens = [(bn.m, bn.bm) for bn in model.bns] + \
                                 [(bn.v, bn.bv) for bn in model.bns]
        self.set_batch_norm_stats = theano.function(
            inputs=[],
            # BUGFIX: wrap zip in list() — under Python 3, zip returns an
            # iterator, which theano.function does not accept as givens.
            # This also matches how givens are built elsewhere in the file.
            givens=list(zip(self.trn_inputs, self.trn_data)),
            updates=[(bn.bm, bn.m) for bn in model.bns] +
                    [(bn.bv, bn.v) for bn in model.bns])
    else:
        self.batch_norm_givens = []
        self.set_batch_norm_stats = None

    # if validation data is given, then set up validation too
    self.do_validation = val_data is not None

    if self.do_validation:
        # prepare validation data
        self.n_val_data = self._get_n_data(val_data)
        self.val_data = [
            theano.shared(x.astype(dtype), borrow=True) for x in val_data
        ]

        # inputs for validation; the validation function is compiled by
        # the subclass
        self.val_inputs = [model.input] if val_target is None else [
            model.input, val_target
        ]
        self.validate = None  # to be implemented by a subclass

        # create checkpointer to store best model seen so far
        self.checkpointer = ModelCheckpointer(model)
        self.best_val_loss = float('inf')

    # initialize some variables
    self.trn_loss = float('inf')
    # fixed seed so minibatch index sampling is reproducible across runs
    self.idx_stream = ds.IndexSubSampler(
        self.n_trn_data, rng=np.random.RandomState(42))
def __init__(self, model, trn_data, trn_loss, trn_target=None, val_data=None,
             val_loss=None, val_target=None, step=None):
    """
    Constructs and configures the trainer.

    :param model: the model to be trained
    :param trn_data: train inputs and (possibly) train targets
    :param trn_loss: theano variable representing the train loss to minimize
    :param trn_target: theano variable representing the train target
    :param val_data: validation inputs and (possibly) validation targets
    :param val_loss: theano variable representing the validation loss
    :param val_target: theano variable representing the validation target
    :param step: step size strategy object; defaults to a fresh ss.Adam()
    :return: None
    """

    # parse input
    # TODO: it would be good to type check the other inputs too
    # BUGFIX: the default step strategy is created here rather than in the
    # signature — a `step=ss.Adam()` default is evaluated once at def time,
    # so every trainer built with the default would share one mutable
    # optimizer-state object.
    if step is None:
        step = ss.Adam()
    assert isinstance(
        step, ss.StepStrategy), 'Step must be a step strategy object.'

    # prepare train data; all input arrays must agree on the leading
    # (number-of-examples) dimension
    n_trn_data_set = {x.shape[0] for x in trn_data}
    assert len(n_trn_data_set) == 1, 'Number of train data is not consistent.'
    self.n_trn_data = n_trn_data_set.pop()
    # TODO: privatise this
    trn_data = [
        theano.shared(x.astype(dtype), borrow=True) for x in trn_data
    ]

    # compile theano function for a single training update on a minibatch
    # selected by the integer index vector `idx`
    grads = tt.grad(trn_loss, model.parms)
    idx = tt.ivector('idx')
    trn_inputs = [model.input
                  ] if trn_target is None else [model.input, trn_target]
    self.make_update = theano.function(
        inputs=[idx],
        outputs=trn_loss,
        givens=list(zip(trn_inputs, [x[idx] for x in trn_data])),
        updates=step.updates(model.parms, grads))

    # if model uses batch norm, compile a theano function for setting up stats
    if getattr(model, 'batch_norm', False):
        batch_norm_givens = [(bn.m, bn.bm) for bn in model.bns] + \
                            [(bn.v, bn.bv) for bn in model.bns]
        self.set_batch_norm_stats = theano.function(
            inputs=[],
            givens=list(zip(trn_inputs, trn_data)),
            updates=[(bn.bm, bn.m) for bn in model.bns] +
                    [(bn.bv, bn.v) for bn in model.bns])
    else:
        self.set_batch_norm_stats = None
        batch_norm_givens = []

    # if validation data is given, then set up validation too
    self.do_validation = val_data is not None

    if self.do_validation:
        # prepare validation data; same consistency check as for train data
        n_val_data_set = {x.shape[0] for x in val_data}
        assert len(
            n_val_data_set) == 1, 'Number of validation data is not consistent.'
        self.n_val_data = n_val_data_set.pop()
        val_data = [
            theano.shared(x.astype(dtype), borrow=True) for x in val_data
        ]

        # compile theano function for validation over the full validation set;
        # batch norm stats are substituted via the givens collected above
        val_inputs = [model.input] if val_target is None else [
            model.input, val_target
        ]
        self.validate = theano.function(
            inputs=[],
            outputs=val_loss,
            givens=list(zip(val_inputs, val_data)) + batch_norm_givens)

        # create checkpointer to store best model seen so far
        self.checkpointer = ModelCheckpointer(model)
        self.best_val_loss = float('inf')

    # initialize some variables
    self.trn_loss = float('inf')
    self.idx_stream = ds.IndexSubSampler(self.n_trn_data)