Example #1
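# NOTE: the original listing omits its imports. The following is a plausible
# set, assuming Keras 2.x on the TensorFlow 1.x backend; `models`,
# `data_utils`, and `project_dir` are project-specific and assumed to exist.
import os
import sys
import time
from datetime import datetime

import numpy as np
import tensorflow as tf
from keras import backend as K
from keras.callbacks import TensorBoard
from keras.optimizers import Adam
from keras.utils import generic_utils
from keras.utils.data_utils import OrderedEnqueuer
from keras.utils.vis_utils import plot_model

import data_utils  # project-specific: DataGenerator, get_disc_batch, ...
import models      # project-specific: generator/discriminator definitions

project_dir = os.path.dirname(os.path.abspath(__file__))  # assumed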
def main(dataset, batch_size, patch_size, epochs, label_smoothing,
         label_flipping):
    print(project_dir)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # dynamically grow the memory used on the GPU
    sess = tf.Session(config=config)
    K.tensorflow_backend.set_session(
        sess)  # set this TensorFlow session as the default session for Keras

    image_data_format = "channels_first"
    K.set_image_data_format(image_data_format)

    save_images_every_n_batches = 30
    save_model_every_n_epochs = 0

    # configuration parameters
    print("Config params:")
    print("  dataset = {}".format(dataset))
    print("  batch_size = {}".format(batch_size))
    print("  patch_size = {}".format(patch_size))
    print("  epochs = {}".format(epochs))
    print("  label_smoothing = {}".format(label_smoothing))
    print("  label_flipping = {}".format(label_flipping))
    print("  save_images_every_n_batches = {}".format(
        save_images_every_n_batches))
    print("  save_model_every_n_epochs = {}".format(save_model_every_n_epochs))

    model_name = datetime.now().strftime('%y%m%d-%H%M')
    model_dir = os.path.join(project_dir, "models", model_name)
    fig_dir = os.path.join(project_dir, "reports", "figures")
    logs_dir = os.path.join(project_dir, "reports", "logs", model_name)

    os.makedirs(model_dir)

    # Load and rescale data
    ds_train_gen = data_utils.DataGenerator(file_path=dataset,
                                            dataset_type="train",
                                            batch_size=batch_size)
    ds_train_disc = data_utils.DataGenerator(file_path=dataset,
                                             dataset_type="train",
                                             batch_size=batch_size)
    ds_val = data_utils.DataGenerator(file_path=dataset,
                                      dataset_type="val",
                                      batch_size=batch_size)
    enq_train_gen = OrderedEnqueuer(ds_train_gen,
                                    use_multiprocessing=True,
                                    shuffle=True)
    enq_train_disc = OrderedEnqueuer(ds_train_disc,
                                     use_multiprocessing=True,
                                     shuffle=True)
    enq_val = OrderedEnqueuer(ds_val, use_multiprocessing=True, shuffle=False)

    img_dim = ds_train_gen[0][0].shape[-3:]

    n_batch_per_epoch = len(ds_train_gen)
    epoch_size = n_batch_per_epoch * batch_size

    print("Derived params:")
    print("  n_batch_per_epoch = {}".format(n_batch_per_epoch))
    print("  epoch_size = {}".format(epoch_size))
    print("  n_batches_val = {}".format(len(ds_val)))

    # Get the number of non-overlapping patches and the discriminator's input image size
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size)

    tensorboard = TensorBoard(log_dir=logs_dir,
                              histogram_freq=0,
                              batch_size=batch_size,
                              write_graph=True,
                              write_grads=True,
                              update_freq='batch')

    try:
        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        opt_discriminator = Adam(lr=1E-3,
                                 beta_1=0.9,
                                 beta_2=0.999,
                                 epsilon=1e-08)

        # Load generator model
        generator_model = models.generator_unet_upsampling(img_dim)
        generator_model.summary()
        plot_model(generator_model,
                   to_file=os.path.join(fig_dir, "generator_model.png"),
                   show_shapes=True,
                   show_layer_names=True)

        # Load discriminator model
        # TODO: modify disc to accept real input as well
        discriminator_model = models.DCGAN_discriminator(
            img_dim_disc, nb_patch)
        discriminator_model.summary()
        plot_model(discriminator_model,
                   to_file=os.path.join(fig_dir, "discriminator_model.png"),
                   show_shapes=True,
                   show_layer_names=True)

        # TODO: pretty sure this is unnecessary
        generator_model.compile(loss='mae', optimizer=opt_discriminator)
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model, discriminator_model,
                                   img_dim, patch_size, image_data_format)

        # L1 loss applies to the generated image; cross-entropy applies to the predicted label
        loss = [models.l1_loss, 'binary_crossentropy']
        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss,
                            loss_weights=loss_weights,
                            optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy',
                                    optimizer=opt_discriminator)

        tensorboard.set_model(DCGAN_model)

        # Start training
        enq_train_gen.start(workers=1, max_queue_size=20)
        enq_train_disc.start(workers=1, max_queue_size=20)
        enq_val.start(workers=1, max_queue_size=20)
        out_train_gen = enq_train_gen.get()
        out_train_disc = enq_train_disc.get()
        out_val = enq_val.get()

        print("Start training")
        for e in range(1, epochs + 1):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            start = time.time()

            for batch_counter in range(1, n_batch_per_epoch + 1):
                X_transformed_batch, X_orig_batch = next(out_train_disc)

                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(
                    X_transformed_batch,
                    X_orig_batch,
                    generator_model,
                    batch_counter,
                    patch_size,
                    label_smoothing=label_smoothing,
                    label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)

                # Create a batch to feed the generator model
                X_gen_target, X_gen = next(out_train_gen)
                y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
                # Label the generated images as real (class 1) so the
                # generator is trained to fool the discriminator
                y_gen[:, 1] = 1

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen,
                                                      [X_gen_target, y_gen])
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                metrics = [("D logloss", disc_loss), ("G tot", gen_loss[0]),
                           ("G L1", gen_loss[1]), ("G logloss", gen_loss[2])]
                progbar.add(batch_size, values=metrics)

                logs = {k: v for (k, v) in metrics}
                logs["size"] = batch_size

                tensorboard.on_batch_end(batch_counter, logs=logs)

                # Save images for visualization
                if batch_counter % save_images_every_n_batches == 0:
                    # Get new images from validation
                    data_utils.plot_generated_batch(
                        X_transformed_batch, X_orig_batch, generator_model,
                        os.path.join(logs_dir, "current_batch_training.png"))
                    X_transformed_batch, X_orig_batch = next(out_val)
                    data_utils.plot_generated_batch(
                        X_transformed_batch, X_orig_batch, generator_model,
                        os.path.join(logs_dir, "current_batch_validation.png"))

            print("")
            print('Epoch %s/%s, Time: %s' % (e, epochs, time.time() - start))
            tensorboard.on_epoch_end(e, logs=logs)

            if (save_model_every_n_epochs >= 1 and e % save_model_every_n_epochs == 0) or \
                    (e == epochs):
                print("Saving model for epoch {}...".format(e), end="")
                sys.stdout.flush()
                gen_weights_path = os.path.join(
                    model_dir, 'gen_weights_epoch{:03d}.h5'.format(e))
                generator_model.save_weights(gen_weights_path, overwrite=True)

                disc_weights_path = os.path.join(
                    model_dir, 'disc_weights_epoch{:03d}.h5'.format(e))
                discriminator_model.save_weights(disc_weights_path,
                                                 overwrite=True)

                DCGAN_weights_path = os.path.join(
                    model_dir, 'DCGAN_weights_epoch{:03d}.h5'.format(e))
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)
                print("done")

    except KeyboardInterrupt:
        pass

    enq_train_gen.stop()
    enq_train_disc.stop()
    enq_val.stop()
Example #2
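# NOTE: assumed imports for this snippet; the original listing omits them.
import logging
import os
import sys
from collections.abc import Iterable  # `from collections import Iterable` on Python 2

import numpy as np
from keras.callbacks import TensorBoard

# The snippet logs through a module-level `logger`, distinct from self.logger.
logger = logging.getLogger(__name__)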
class NLITaskTrain(object):
    def __init__(self,
                 model,
                 train_data,
                 test_data,
                 dev_data=None,
                 optimizer=None,
                 log_dir=None,
                 save_dir=None,
                 name=None):
        self.model = model
        self.name = name
        """Data"""
        self.train_label = train_data[-1]
        self.train_data = train_data[:-1]
        self.test_data = test_data
        self.dev_data = dev_data
        if self.dev_data is not None:
            self.dev_label = self.dev_data[-1]
            self.dev_data = self.dev_data[:-1]
        """Train Methods"""
        self.optimizer = optimizer
        self.current_optimizer = None
        self.current_optimizer_id = -1
        self.current_switch_steps = -1
        """Others"""
        self.log_dir = log_dir
        if self.log_dir is not None and not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)
        self.logger = TensorBoard(log_dir=self.log_dir)

        self.save_dir = save_dir
        if self.save_dir is not None and not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)

    def train(self, batch_size=128, eval_interval=512, shuffle=True):
        # Stub: single-optimizer training is not implemented in this snippet.
        return

    def train_multi_optimizer(self,
                              batch_size=128,
                              eval_interval=512,
                              shuffle=True):
        assert isinstance(self.optimizer, Iterable)
        assert len(self.optimizer) > 1

        self.current_optimizer = None
        self.current_optimizer_id = -1
        self.current_switch_steps = -1

        self.init_optimizer()
        self.model.summary()

        train_steps, no_progress_steps, epoch = 0, 0, 0
        train_batch_start = 0
        best_loss = np.inf

        # Loops until switch_optimizer() exhausts the optimizer list and exits.
        while True:
            if shuffle:
                random_index = np.random.permutation(len(self.train_label))
                self.train_data = [
                    data[random_index] for data in self.train_data
                ]
                self.train_label = self.train_label[random_index]

            dev_loss, dev_acc = self.evaluate(batch_size=batch_size)
            self.logger.on_epoch_end(epoch=epoch,
                                     logs={
                                         "dev_loss": dev_loss,
                                         "dev_acc": dev_acc
                                     })
            self.model.save(
                os.path.join(
                    self.save_dir,
                    "epoch{}-loss{}-acc{}.model".format(epoch, dev_loss,
                                                        dev_acc)))
            epoch += 1
            no_progress_steps += 1

            if dev_loss < best_loss:
                best_loss = dev_loss
                no_progress_steps = 0

            if no_progress_steps > self.current_switch_steps:
                self.switch_optimizer()
                no_progress_steps = 0

            for i in range(eval_interval):
                train_loss, train_acc = self.model.train_on_batch([
                    data[train_batch_start:train_batch_start + batch_size]
                    for data in self.train_data
                ], self.train_label[train_batch_start:train_batch_start +
                                    batch_size])
                self.logger.on_batch_end(train_steps,
                                         logs={
                                             "train_loss": train_loss,
                                             "train_acc": train_acc
                                         })

                train_steps += 1
                train_batch_start += batch_size
                # Wrap around; >= avoids feeding an empty final batch slice
                if train_batch_start >= len(self.train_label):
                    train_batch_start = 0
                    if shuffle:
                        random_index = np.random.permutation(
                            len(self.train_label))
                        self.train_data = [
                            data[random_index] for data in self.train_data
                        ]
                        self.train_label = self.train_label[random_index]

    def init_optimizer(self):
        self.current_optimizer_id = 0
        self.current_optimizer, self.current_switch_steps = self.optimizer[
            self.current_optimizer_id]
        self.model.compile(optimizer=self.current_optimizer,
                           loss="binary_crossentropy",
                           metrics=["acc"])
        self.logger.set_model(self.model)
        logger.info("Switch to {} optimizer".format(self.current_optimizer))

    def evaluate(self, X=None, y=None, batch_size=None):
        if X is None:
            X, y = self.dev_data, self.dev_label

        loss, acc = self.model.evaluate(X, y, batch_size=batch_size)
        return loss, acc

    def switch_optimizer(self):
        self.current_optimizer_id += 1
        if self.current_optimizer_id >= len(self.optimizer):
            logger.info("Training process finished")
            sys.exit(0)

        self.current_optimizer, self.current_switch_steps = self.optimizer[
            self.current_optimizer_id]
        self.model.compile(optimizer=self.current_optimizer,
                           loss="binary_crossentropy",
                           metrics=["acc"])
        self.logger.set_model(self.model)
        logger.info("Switch to {} optimizer".format(self.current_optimizer))
Example #3
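# NOTE: assumed imports, inferred from the names this snippet uses. RBPN,
# get_training_set, get_eval_set, load_weights, named_logs, and psnr are
# project-specific helpers defined elsewhere in the repository.
import logging
import os

from absl import flags
from keras import losses, optimizers
from keras.callbacks import TensorBoard

FLAGS = flags.FLAGS  # assumed: the flags themselves are defined elsewhere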
def main(not_parsed_args):
    logging.info('Build dataset')
    train_set = get_training_set(FLAGS.dataset_h, FLAGS.dataset_l,
                                 FLAGS.frames, FLAGS.scale, True,
                                 'filelist.txt', True, FLAGS.patch_size,
                                 FLAGS.future_frame)
    if FLAGS.dataset_val:
        val_set = get_eval_set(FLAGS.dataset_val_h, FLAGS.dataset_val_l,
                               FLAGS.frames, FLAGS.scale, True, 'filelist.txt',
                               True, FLAGS.patch_size, FLAGS.future_frame)

    logging.info('Build model')
    model = RBPN()
    model.summary()
    last_epoch, last_step = load_weights(model)
    model.compile(optimizer=optimizers.Adam(FLAGS.lr),
                  loss=losses.mae,
                  metrics=[psnr])

    # checkpoint = ModelCheckpoint('models/model.hdf5', verbose=1)
    tensorboard = TensorBoard(log_dir='./tf_logs',
                              batch_size=FLAGS.batch_size,
                              write_graph=False,
                              write_grads=True,
                              write_images=True,
                              update_freq='batch')
    tensorboard.set_model(model)

    logging.info('Training start')
    for e in range(last_epoch, FLAGS.epochs):
        tensorboard.on_epoch_begin(e)
        for s in range(last_step + 1, len(train_set) // FLAGS.batch_size):
            tensorboard.on_batch_begin(s)
            x, y = train_set.batch(FLAGS.batch_size)
            loss = model.train_on_batch(x, y)
            print('Epoch %d step %d, loss %f psnr %f' %
                  (e, s, loss[0], loss[1]))
            tensorboard.on_batch_end(s, named_logs(model, loss, s))

            # Parentheses matter here: without them the `or` branch fires even
            # when FLAGS.dataset_val is unset and `val_set` does not exist.
            if FLAGS.dataset_val and (s > 0 and s % FLAGS.val_interval == 0 or
                                      s == len(train_set) // FLAGS.batch_size - 1):
                logging.info('Validation start')
                val_loss = 0
                val_psnr = 0
                for j in range(len(val_set)):
                    x_val, y_val = val_set.batch(1)
                    score = model.test_on_batch(x_val, y_val)
                    val_loss += score[0]
                    val_psnr += score[1]
                val_loss /= len(val_set)
                val_psnr /= len(val_set)
                logging.info('Validation average loss %f psnr %f' %
                             (val_loss, val_psnr))

            if (s > 0 and s % FLAGS.save_interval == 0) or \
                    s == len(train_set) // FLAGS.batch_size - 1:
                logging.info('Saving model')
                filename = 'model_%d_%d.h5' % (e, s)
                path = os.path.join(FLAGS.model_dir, filename)
                path_info = os.path.join(FLAGS.model_dir, 'info')
                model.save_weights(path)
                with open(path_info, 'w') as f:
                    f.write(filename)
        tensorboard.on_epoch_end(e)
        last_step = -1  # resume the next epoch from step 0
Example #4
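# NOTE: assumed imports; make_dir, timeit, pairwise, search_layer,
# ResetStatesCallback, and PoseMetrics are project-specific helpers.
import os

import matplotlib.pyplot as plt
import numpy as np
import quaternion  # the numpy-quaternion package
from keras.callbacks import Callback, CSVLogger, TensorBoard
from keras.models import Model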
class ExtendedLogger(Callback):

    # NOTE: class attribute, shared across all ExtendedLogger instances
    val_data_metrics = {}

    def __init__(self,
                 prediction_layer,
                 output_dir='./tmp',
                 stateful=False,
                 stateful_reset_interval=None,
                 starting_indicies=None):

        if stateful and stateful_reset_interval is None:
            raise ValueError(
                'If the model is stateful, stateful_reset_interval must be set!')

        super(ExtendedLogger, self).__init__()

        self.csv_dir = os.path.join(output_dir, 'csv')
        self.tb_dir = os.path.join(output_dir, 'tensorboard')
        self.pred_dir = os.path.join(output_dir, 'predictions')
        self.plot_dir = os.path.join(output_dir, 'plots')

        make_dir(self.csv_dir)
        make_dir(self.tb_dir)
        make_dir(self.plot_dir)
        make_dir(self.pred_dir)

        self.stateful = stateful
        self.stateful_reset_interval = stateful_reset_interval
        self.starting_indicies = starting_indicies
        self.csv_logger = CSVLogger(os.path.join(self.csv_dir, 'run.csv'))
        self.tensorboard = TensorBoard(log_dir=self.tb_dir, write_graph=True)
        self.prediction_layer = prediction_layer

    def set_params(self, params):
        super(ExtendedLogger, self).set_params(params)
        self.tensorboard.set_params(params)
        self.tensorboard.batch_size = params['batch_size']
        self.csv_logger.set_params(params)

    def set_model(self, model):
        super(ExtendedLogger, self).set_model(model)
        self.tensorboard.set_model(model)
        self.csv_logger.set_model(model)

    def on_batch_begin(self, batch, logs=None):
        self.csv_logger.on_batch_begin(batch, logs=logs)
        self.tensorboard.on_batch_begin(batch, logs=logs)

    def on_batch_end(self, batch, logs=None):
        self.csv_logger.on_batch_end(batch, logs=logs)
        self.tensorboard.on_batch_end(batch, logs=logs)

    def on_train_begin(self, logs=None):
        self.csv_logger.on_train_begin(logs=logs)
        self.tensorboard.on_train_begin(logs=logs)

    def on_train_end(self, logs=None):
        self.csv_logger.on_train_end(logs=logs)
        self.tensorboard.on_train_end(logs)

    def on_epoch_begin(self, epoch, logs=None):
        self.csv_logger.on_epoch_begin(epoch, logs=logs)
        self.tensorboard.on_epoch_begin(epoch, logs=logs)

    def on_epoch_end(self, epoch, logs=None):

        with timeit('metrics'):

            outputs = self.model.get_layer(self.prediction_layer).output
            self.prediction_model = Model(inputs=self.model.input,
                                          outputs=outputs)

            batch_size = self.params['batch_size']

            # Keras may append sample weights and a float learning-phase flag
            # to validation_data; strip them to keep only inputs and targets.
            if isinstance(self.validation_data[-1], float):
                val_data = self.validation_data[:-2]
            else:
                val_data = self.validation_data[:-1]

            y_true = val_data[1]

            callback = None
            if self.stateful:
                callback = ResetStatesCallback(
                    interval=self.stateful_reset_interval)
                callback.model = self.prediction_model

            # `predict` takes a `callbacks` list in recent Keras versions
            y_pred = self.prediction_model.predict(
                val_data[:-1],
                batch_size=batch_size,
                verbose=1,
                callbacks=[callback] if callback else None)

            print(y_true.shape, y_pred.shape)

            self.write_prediction(epoch, y_true, y_pred)

            # Each pose: (x, y, z) position followed by a quaternion stored
            # as (qx, qy, qz, qw)
            y_true = y_true.reshape((-1, 7))
            y_pred = y_pred.reshape((-1, 7))

            self.save_error_histograms(epoch, y_true, y_pred)
            self.save_topview_trajectories(epoch, y_true, y_pred)

            new_logs = {
                name: np.array(metric(y_true, y_pred))
                for name, metric in self.val_data_metrics.items()
            }
            logs.update(new_logs)

            homo_logs = self.try_add_homoscedastic_params()
            logs.update(homo_logs)

            self.tensorboard.validation_data = self.validation_data
            self.csv_logger.validation_data = self.validation_data

            self.tensorboard.on_epoch_end(epoch, logs=logs)
            self.csv_logger.on_epoch_end(epoch, logs=logs)

    def add_validation_metrics(self, metrics_dict):
        self.val_data_metrics.update(metrics_dict)

    def add_validation_metric(self, name, metric):
        self.val_data_metrics[name] = metric

    def try_add_homoscedastic_params(self):
        homo_pos_loss_layer = search_layer(self.model, 'homo_pos_loss')
        homo_quat_loss_layer = search_layer(self.model, 'homo_quat_loss')

        if homo_pos_loss_layer and homo_quat_loss_layer:
            return {
                'pos_log_var': np.array(homo_pos_loss_layer.get_weights()[0]),
                'quat_log_var': np.array(homo_quat_loss_layer.get_weights()[0]),
            }
        else:
            return {}

    def write_prediction(self, epoch, y_true, y_pred):
        filename = '{:04d}_predictions.npy'.format(epoch)
        filename = os.path.join(self.pred_dir, filename)
        arr = {'y_pred': y_pred, 'y_true': y_true}
        np.save(filename, arr)

    def save_topview_trajectories(self,
                                  epoch,
                                  y_true,
                                  y_pred,
                                  max_segment=1000):

        if self.starting_indicies is None:
            self.starting_indicies = {
                'valid': list(range(0, 4000, 1000)) + [4000]
            }

        for begin, end in pairwise(self.starting_indicies['valid']):

            diff = end - begin
            if diff > max_segment:
                subindicies = list(range(begin, end, max_segment)) + [end]
                for b, e in pairwise(subindicies):
                    self.save_trajectory(epoch, y_true, y_pred, b, e)

            self.save_trajectory(epoch, y_true, y_pred, begin, end)

    def save_trajectory(self, epoch, y_true, y_pred, begin, end):
        true_xy, pred_xy = y_true[begin:end, :2], y_pred[begin:end, :2]

        # Heading angles derived from the quaternions (computed here but not
        # used by the plot below)
        true_q = quaternion.as_quat_array(y_true[begin:end, [6, 3, 4, 5]])
        true_q = quaternion.as_euler_angles(true_q)[1]

        pred_q = quaternion.as_quat_array(y_pred[begin:end, [6, 3, 4, 5]])
        pred_q = quaternion.as_euler_angles(pred_q)[1]

        plt.clf()

        plt.plot(true_xy[:, 0], true_xy[:, 1], 'g-')
        plt.plot(pred_xy[:, 0], pred_xy[:, 1], 'r-')

        for ((x1, y1), (x2, y2)) in zip(true_xy, pred_xy):
            plt.plot([x1, x2], [y1, y2],
                     color='k',
                     linestyle='-',
                     linewidth=0.3,
                     alpha=0.2)

        plt.grid(True)
        plt.xlabel('x [m]')
        plt.ylabel('y [m]')
        plt.title('Top-down view of trajectory')
        plt.axis('equal')

        x_range = (np.min(true_xy[:, 0]) - .2, np.max(true_xy[:, 0]) + .2)
        y_range = (np.min(true_xy[:, 1]) - .2, np.max(true_xy[:, 1]) + .2)

        plt.xlim(x_range)
        plt.ylim(y_range)

        filename = 'epoch={epoch:04d}_begin={begin:04d}_end={end:04d}_trajectory.pdf' \
          .format(epoch=epoch, begin=begin, end=end)
        filename = os.path.join(self.plot_dir, filename)
        plt.savefig(filename)

    def save_error_histograms(self, epoch, y_true, y_pred):
        pos_errors = PoseMetrics.abs_errors_position(y_true, y_pred)
        pos_errors = np.sort(pos_errors)

        angle_errors = PoseMetrics.abs_errors_orienation(y_true, y_pred)
        angle_errors = np.sort(angle_errors)

        size = len(y_true)
        ys = np.arange(size) / float(size)

        plt.clf()

        plt.subplot(2, 1, 1)
        plt.title('Empirical CDF of absolute errors')
        plt.grid(True)
        plt.plot(pos_errors, ys, 'k-')
        plt.xlabel('Absolute Position Error (m)')
        plt.xlim(0, 1.2)

        plt.subplot(2, 1, 2)
        plt.grid(True)
        plt.plot(angle_errors, ys, 'r-')
        plt.xlabel('Absolute Angle Error (deg)')
        plt.xlim(0, 70)

        filename = '{:04d}_cdf.pdf'.format(epoch)
        filename = os.path.join(self.plot_dir, filename)
        plt.savefig(filename)
Example #5
File: model.py Project: nptdat/qanet
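# NOTE: this is a method of the project's QANet model class. Assumed imports
# for its dependencies; QANetCallback and get_batch are defined elsewhere in
# the project.
import pandas as pd
from keras.callbacks import TensorBoard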
    def train(self, batch_size=4, epochs=25):
        cf = self.cf
        self.compile()
        model = self.keras_model
        (word_vectors, char_vectors, train_ques_ids, X_train, y_train,
         val_ques_ids, X_valid, y_valid) = self.data_train

        qanet_cb = QANetCallback(decay=cf.EMA_DECAY)
        tb = TensorBoard(log_dir=cf.TENSORBOARD_PATH,
                         histogram_freq=0,
                         write_graph=False,
                         write_images=False,
                         update_freq=cf.TENSORBOARD_UPDATE_FREQ)

        # Call set_model for all callbacks
        qanet_cb.set_model(model)
        tb.set_model(model)

        ep_list = []
        avg_train_loss_list = []
        em_score_list = []
        f1_score_list = []

        global_steps = 0
        gt_start_list, gt_end_list = y_valid[2:]
        for ep in range(1, epochs + 1):  # epoch numbering starts at 1
            print('----------- Training for epoch {}...'.format(ep))
            # Train
            batch = 0
            sum_loss = 0
            num_batches = (len(X_train[0]) - 1) // batch_size + 1  # ceiling division
            for X_batch, y_batch in get_batch(X_train,
                                              y_train,
                                              batch_size=batch_size,
                                              shuffle=True):
                batch_logs = {'batch': batch, 'size': len(X_batch[0])}
                tb.on_batch_begin(batch, batch_logs)

                loss, loss_p1, loss_p2, loss_start, loss_end = model.train_on_batch(
                    X_batch, y_batch)
                sum_loss += loss
                avg_loss = sum_loss / (batch + 1)
                print(
                    'Epoch: {}/{}, Batch: {}/{}, Accumulative average loss: {:.4f}, Loss: {:.4f}, Loss_P1: {:.4f}, Loss_P2: {:.4f}, Loss_start: {:.4f}, Loss_end: {:.4f}'
                    .format(ep, epochs, batch, num_batches, avg_loss, loss,
                            loss_p1, loss_p2, loss_start, loss_end))
                batch_logs.update({
                    'loss': loss,
                    'loss_p1': loss_p1,
                    'loss_p2': loss_p2
                })
                qanet_cb.on_batch_end(batch, batch_logs)
                tb.on_batch_end(batch, batch_logs)

                global_steps += 1
                batch += 1

            ep_list.append(ep)
            avg_train_loss_list.append(avg_loss)

            print('Backing up temp weights...')
            model.save_weights(cf.TEMP_MODEL_PATH)
            qanet_cb.on_epoch_end(ep)  # Apply EMA weights
            model.save_weights(cf.MODEL_PATH % str(ep))

            print('----------- Validating for epoch {}...'.format(ep))
            valid_scores = self.validate(X_valid,
                                         y_valid,
                                         gt_start_list,
                                         gt_end_list,
                                         batch_size=cf.BATCH_SIZE)
            em_score_list.append(valid_scores['exact_match'])
            f1_score_list.append(valid_scores['f1'])
            print(
                '------- Result of epoch: {}/{}, Average_train_loss: {:.6f}, EM: {:.4f}, F1: {:.4f}\n'
                .format(ep, epochs, avg_loss, valid_scores['exact_match'],
                        valid_scores['f1']))

            tb.on_epoch_end(ep, {
                'f1': valid_scores['f1'],
                'em': valid_scores['exact_match']
            })

            # Write result to CSV file
            result = pd.DataFrame({
                'epoch': ep_list,
                'avg_train_loss': avg_train_loss_list,
                'em': em_score_list,
                'f1': f1_score_list
            })
            result.to_csv(cf.RESULT_LOG, index=False)

            # Restore the original weights to continue training
            print('Restoring temp weights...')
            model.load_weights(cf.TEMP_MODEL_PATH)

        tb.on_train_end(None)