def train(self,
              epochs,
              steps_per_epoch,
              initial_epoch=0,
              end_of_epoch_callback=None,
              verbose=1):
        """Train every model in ``self.models`` in turn, step by step, for several epochs.

        On each step, model ``i`` receives one ``(x, y)`` batch drawn from
        ``self.output_generators[i]`` and is updated with ``train_on_batch``.
        The resulting losses are stored in ``self.batch_logs`` and reported
        through a ``ProgbarLogger`` that counts steps.

        Args:
            epochs: Exclusive upper bound on the epoch index. Training runs
                epochs ``initial_epoch`` through ``epochs - 1``.
            steps_per_epoch: Number of steps (batches) in each epoch.
            initial_epoch: Epoch index to start counting from.
            end_of_epoch_callback: Optional callable. If given, it is called
                with the epoch index after each epoch.
            verbose: Verbosity value forwarded to the progress logger's params.
        """
        logger = ProgbarLogger(count_mode='steps')
        logger.set_params({
            'epochs': epochs,
            'steps': steps_per_epoch,
            'verbose': verbose,
            'metrics': self.metric_names
        })
        logger.on_train_begin()

        for epoch in range(initial_epoch, epochs):
            logger.on_epoch_begin(epoch)

            for batch in range(steps_per_epoch):
                self.batch_logs['batch'] = batch
                logger.on_batch_begin(batch, self.batch_logs)

                for i, model in enumerate(self.models):
                    # Indexed (not zipped) so a generators/models length
                    # mismatch still fails loudly with IndexError.
                    generator = self.output_generators[i]
                    x, y = next(generator)

                    # If the generator hands back a batch whose x is None,
                    # discard it and pull exactly one more. Presumably a
                    # sentinel/priming value -- confirm with the generator.
                    if x is None:
                        x, y = next(generator)

                    outs = model.train_on_batch(x, y)
                    if not isinstance(outs, list):
                        outs = [outs]

                    if self.print_full_losses:
                        # NOTE(review): every model writes its outputs to
                        # metric_names[0:len(outs)], so with several models
                        # later ones overwrite earlier ones under the same
                        # keys -- confirm this layout is intended.
                        for name, value in zip(self.metric_names, outs):
                            self.batch_logs[name] = value
                    else:
                        # Only the first output (total loss) of model i is
                        # reported, under the i-th metric name.
                        self.batch_logs[self.metric_names[i]] = outs[0]

                logger.on_batch_end(batch, self.batch_logs)

            logger.on_epoch_end(epoch)
            if end_of_epoch_callback is not None:
                end_of_epoch_callback(epoch)

        # Close the callback lifecycle opened by on_train_begin().
        logger.on_train_end()
Exemplo n.º 2
0
 def on_epoch_end(self, epoch, logs=None):
     """Add the validation ROC AUC to ``logs``, then run ProgbarLogger's epoch-end.

     Args:
         epoch: Index of the epoch that just finished.
         logs: Per-epoch metric dict. ``'val_roc_auc'`` is written into it.
             A fresh dict is used when ``None`` is passed.
     """
     if logs is None:
         logs = {}

     # Register the metric name once so the progress bar displays it. The
     # guard must test the same name that is appended; otherwise the name
     # would be re-appended (and displayed repeatedly) every epoch.
     if 'val_roc_auc' not in self.params['metrics']:
         self.params['metrics'].append('val_roc_auc')

     # NOTE(review): assumes validation_data is laid out as
     # [input_0, ..., input_k, targets, <two trailing entries>] (for example,
     # sample weights and a learning-phase flag). Confirm for the Keras
     # version in use.
     X_test = list(self.model.validation_data[0:-3])
     Y_test = self.model.validation_data[-3]
     val_pred = self.model.predict_proba(X_test, verbose=0)

     # Column 1 is used as the positive-class score. This assumes two-column
     # (one-hot / softmax) binary targets -- TODO confirm.
     logs['val_roc_auc'] = roc_auc_score(Y_test[:, 1], val_pred[:, 1])
     ProgbarLogger.on_epoch_end(self, epoch, logs)