def train(self,
          epochs,
          steps_per_epoch,
          initial_epoch=0,
          end_of_epoch_callback=None,
          verbose=1):
        """Train every model in ``self.models`` step by step with a progress bar.

        Each step trains each model once, on one batch drawn from the
        generator at the same index in ``self.output_generators``, and
        records the results in ``self.batch_logs`` for a ``ProgbarLogger``.

        Args:
            epochs: Epoch index at which training stops (exclusive).
            steps_per_epoch: Number of steps run in every epoch.
            initial_epoch: Epoch index to start counting from.
            end_of_epoch_callback: Optional callable invoked with the epoch
                index after the logger's end-of-epoch hook.
            verbose: Verbosity value forwarded to the logger's params.

        Returns:
            None.
        """

        def draw_batch(source):
            # A generator that hands back ``None`` inputs gets exactly one
            # more pull; the second result is used whatever it is.
            inputs, targets = next(source)
            if inputs is None:
                inputs, targets = next(source)
            return inputs, targets

        progress = ProgbarLogger(count_mode='steps')
        progress.set_params({
            'epochs': epochs,
            'steps': steps_per_epoch,
            'verbose': verbose,
            'metrics': self.metric_names
        })
        progress.on_train_begin()

        current_epoch = initial_epoch
        while current_epoch < epochs:
            progress.on_epoch_begin(current_epoch)

            # The step counter doubles as the batch index given to the logger.
            step_index = 0
            while step_index < steps_per_epoch:
                self.batch_logs['batch'] = step_index
                progress.on_batch_begin(step_index, self.batch_logs)

                for model_index, model in enumerate(self.models):
                    inputs, targets = draw_batch(
                        self.output_generators[model_index])
                    results = model.train_on_batch(inputs, targets)

                    # Normalise a single return value to a one-element list.
                    if not isinstance(results, list):
                        results = [results]

                    if not self.print_full_losses:
                        # Only the first output is logged, under the name at
                        # this model's index.
                        self.batch_logs[self.metric_names[model_index]] = results[0]
                        continue

                    # Outputs are paired with ``metric_names`` from its start
                    # for every model, so each model writes the leading keys.
                    # NOTE(review): confirm later models overwriting earlier
                    # ones here is intended.
                    for metric_name, value in zip(self.metric_names, results):
                        self.batch_logs[metric_name] = value

                progress.on_batch_end(step_index, self.batch_logs)
                step_index += 1

            progress.on_epoch_end(current_epoch)
            if end_of_epoch_callback is not None:
                end_of_epoch_callback(current_epoch)

            current_epoch += 1
# Example no. 2
# 0
 def on_train_begin(self, logs=None):
     """Run ``ProgbarLogger.on_train_begin``, then set ``verbose`` from ``verbose2``.

     Args:
         logs: Passed through unchanged to ``ProgbarLogger.on_train_begin``.
     """
     ProgbarLogger.on_train_begin(self, logs)
     # Assigned after the parent hook so this value is the one left in place.
     # NOTE(review): presumably the parent hook overwrites ``self.verbose``
     # and ``verbose2`` holds the value this class wants to keep — confirm
     # where ``verbose2`` is set.
     self.verbose = self.verbose2