Example 1
    def get_rec_image(self, x, s):
        z, _ = self.enc_modality.predict([s, x])
        gaussian = NormalDistribution()

        y = self.reconstructor.predict([s, z])
        y_s0 = self.reconstructor.predict([s, np.zeros(z.shape)])
        all_bkg = np.concatenate([
            np.zeros(s.shape[:-1] + (s.shape[-1] - 1,)),
            np.ones(s.shape[:-1] + (1,))
        ], axis=-1)
        y_0z = self.reconstructor.predict([all_bkg, z])
        y_00 = self.reconstructor.predict([all_bkg, np.zeros(z.shape)])
        z_random = gaussian.sample(z.shape)
        y_random = self.reconstructor.predict([s, z_random])
        rows = [
            np.concatenate(
                [x[i, :, :, 0], y[i, :, :, 0], y_random[i, :, :, 0], y_s0[i, :, :, 0]]
                + [self.reconstructor.predict([get_s0chn(k, s), z])[i, :, :, 0]
                   for k in range(s.shape[-1] - 1)]
                + [y_0z[i, :, :, 0], y_00[i, :, :, 0]],
                axis=1)
            for i in range(x.shape[0])
        ]
        header = utils.image_utils.makeTextHeaderImage(
            x.shape[2],
            ['X', 'rec(s,z)', 'rec(s,~z)', 'rec(s,0)']
            + ['rec(s0_%d, z)' % k for k in range(s.shape[-1] - 1)]
            + ['rec(0, z)', 'rec(0,0)'])
        im_plot = np.concatenate([header] + rows, axis=0)
        im_plot = np.clip(im_plot, -1, 1)
        return im_plot
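Neither NormalDistribution nor get_s0chn is defined in this snippet. A minimal sketch of what they plausibly look like, assuming NormalDistribution is a thin NumPy wrapper exposing a sample(shape) method and get_s0chn zeroes one channel of the anatomy maps (both behaviours are inferred from the call sites, not taken from the source):

import numpy as np

class NormalDistribution:
    """Assumed helper: standard normal N(0, 1) sampler."""
    def sample(self, shape):
        return np.random.normal(loc=0.0, scale=1.0, size=shape)

def get_s0chn(k, s):
    """Assumed helper: copy of the anatomy maps s with channel k zeroed."""
    s_copy = s.copy()
    s_copy[..., k] = 0
    return s_copy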
Example 2
    def prepare_data_to_train(self, x1_pairs, x2_pairs, m1_pairs, m2_pairs):
        if m2_pairs is not None:
            [x1_pairs, x2_pairs, m1_pairs, m2_pairs] = \
                self.align_batches([x1_pairs, x2_pairs, m1_pairs, m2_pairs])
        else:
            [x1_pairs, x2_pairs, m1_pairs] = \
                self.align_batches([x1_pairs, x2_pairs, m1_pairs])

        split_images = lambda x: [x[..., i:i + 1] for i in range(self.conf.n_pairs)]
        x1_list = split_images(x1_pairs)
        x2_list = split_images(x2_pairs)
        x1 = x1_list[0]
        x2 = x2_list[0]
        m1 = self.add_residual(m1_pairs[..., 0:self.loader.num_masks])
        m2 = self.add_residual(m2_pairs[..., 0:self.loader.num_masks]) \
            if m2_pairs is not None else None

        batch_size = x1.shape[0]  # may differ from conf.batch_size for the last batch
        norm = NormalDistribution()
        z1 = norm.sample((batch_size, self.conf.num_z))
        z2 = norm.sample((batch_size, self.conf.num_z))
        return m1, m2, x1, x1_list, x2, x2_list, z1, z2
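align_batches is not shown in this example. A plausible minimal implementation, assuming its only job is to truncate every array to the smallest batch size so paired inputs stay aligned (a guess from how it is called, not the source implementation):

import numpy as np

def align_batches(array_list):
    # assumed behaviour: keep the first min_size elements of each array
    min_size = min(a.shape[0] for a in array_list)
    return [a[:min_size] for a in array_list]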
Example 3
def vae_sample(args):
    # reparameterisation trick: z = mu + sigma * epsilon, with sigma = exp(0.5 * log_var)
    z_mean, z_log_var = args
    batch = z_mean.shape[0]
    dim = z_mean.shape[1]
    # by default, NormalDistribution samples with mean=0 and std=1.0
    gaussian = NormalDistribution()
    epsilon = gaussian.sample((batch, dim))
    return z_mean + np.exp(0.5 * z_log_var) * epsilon
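A quick usage sketch with dummy NumPy arrays, reusing the NormalDistribution sketch from Example 1 (the shapes are illustrative only):

import numpy as np

z_mean = np.zeros((4, 8))      # batch of 4, latent dimension 8
z_log_var = np.zeros((4, 8))   # log-variance 0, so sigma = 1
z = vae_sample([z_mean, z_log_var])
assert z.shape == (4, 8)       # here z ~ N(0, 1), since mu = 0 and sigma = 1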
Example 4
    def plot_reconstructions(self, lb_images, ul_images, epoch):
        """
        Plot two images showing the combination of the spatial and modality LR to generate an image. The first
        image uses the predicted S and Z and the second samples Z from a Gaussian.
        :param lb_images:   a list of 2 4-dim arrays of images + corresponding masks
        :param ul_images:   a list of 4-dim image arrays
        :param epoch:       the epoch number
        """

        # combine labelled and unlabelled images and randomly sample 4 examples
        images = lb_images[0]  # [el[0] for el in lb_images]
        if len(ul_images) > 0:
            images = np.concatenate([images, ul_images], axis=0)
        x = utils.data_utils.sample(images, nb_samples=4)

        # S + Z -> Image
        gaussian = NormalDistribution()

        s = self.enc_anatomy.predict(x)
        z, _ = self.enc_modality.predict([s, x])

        y = self.reconstructor.predict([s, z])
        y_s0 = self.reconstructor.predict([s, np.zeros(z.shape)])
        all_bkg = np.concatenate([
            np.zeros(s.shape[:-1] + (s.shape[-1] - 1,)),
            np.ones(s.shape[:-1] + (1,))
        ], axis=-1)
        y_0z = self.reconstructor.predict([all_bkg, z])
        y_00 = self.reconstructor.predict([all_bkg, np.zeros(z.shape)])
        z_random = gaussian.sample(z.shape)
        y_random = self.reconstructor.predict([s, z_random])

        rows = [
            np.concatenate(
                [x[i, :, :, 0], y[i, :, :, 0], y_random[i, :, :, 0], y_s0[i, :, :, 0]]
                + [self.reconstructor.predict([self._get_s0chn(k, s), z])[i, :, :, 0]
                   for k in range(s.shape[-1] - 1)]
                + [y_0z[i, :, :, 0], y_00[i, :, :, 0]],
                axis=1)
            for i in range(x.shape[0])
        ]
        header = utils.image_utils.makeTextHeaderImage(
            x.shape[2],
            ['X', 'rec(s,z)', 'rec(s,~z)', 'rec(s,0)']
            + ['rec(s0_%d, z)' % k for k in range(s.shape[-1] - 1)]
            + ['rec(0, z)', 'rec(0,0)'])
        im_plot = np.concatenate([header] + rows, axis=0)
        im_plot = np.clip(im_plot, -1, 1)
        # scipy.misc.imsave was removed in SciPy 1.2; imageio.imwrite is a drop-in replacement
        scipy.misc.imsave(self.rec_folder + '/rec_epoch_%d.png' % epoch, im_plot)

        plt.figure()
        plt.imshow(im_plot, cmap='gray')
        plt.xticks([])
        plt.yticks([])
        plt.tight_layout()
        plt.close()
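utils.data_utils.sample is also not shown; judging by the call site it draws nb_samples random examples along the batch axis. A minimal sketch under that assumption (name and signature inferred from usage):

import numpy as np

def sample(images, nb_samples, seed=None):
    # assumed behaviour: uniformly sample whole examples without replacement
    rng = np.random.RandomState(seed)
    idx = rng.choice(images.shape[0], size=nb_samples, replace=False)
    return images[idx]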
Example 5
    def train_batch_generators(self, epoch_loss):
        """
        Train generator/segmentation networks.
        :param epoch_loss:  Dictionary of losses for the epoch
        """
        if self.gen_labelled is not None:
            x, m, scanner = next(self.gen_labelled)
            batch_size = x.shape[0]  # may differ from conf.batch_size for the last batch

            # Train labelled path (G_supervised_model)
            h = self.model.G_supervised_trainer.fit(
                x, [m, x, np.zeros(batch_size)], epochs=1, verbose=0)
            epoch_loss['supervised_Mask'].append(h.history['Segmentor_loss'])
            epoch_loss['rec_X'].append(h.history['Reconstructor_loss'])
            epoch_loss['KL'].append(h.history['Enc_Modality_loss'])

            # Train Z Regressor
            if self.model.Z_Regressor is not None:
                s = self.model.Enc_Anatomy.predict(x)
                sample_z = NormalDistribution().sample((batch_size, self.conf.num_z))
                h = self.model.Z_Regressor.fit([s, sample_z], sample_z,
                                               epochs=1, verbose=0)
                epoch_loss['rec_Z'].append(h.history['loss'])

        # Train unlabelled path
        if self.gen_unlabelled is not None:
            x = next(self.gen_unlabelled)
            batch_size = x.shape[0]  # may differ from conf.batch_size for the last batch

            # Train unlabelled path (G_model)
            h = self.model.G_trainer.fit(
                x,
                [np.ones((batch_size,) + self.model.D_Mask.output_shape[1:]),
                 x,
                 np.zeros(batch_size)],
                epochs=1, verbose=0)
            epoch_loss['adv_M'].append(h.history['D_Mask_loss'])
            epoch_loss['rec_X'].append(h.history['Reconstructor_loss'])
            epoch_loss['KL'].append(h.history['Enc_Modality_loss'])

            # Train Z Regressor
            if self.model.Z_Regressor is not None:
                s = self.model.Enc_Anatomy.predict(x)
                sample_z = NormalDistribution().sample((batch_size, self.conf.num_z))
                h = self.model.Z_Regressor.fit([s, sample_z], sample_z,
                                               epochs=1, verbose=0)
                epoch_loss['rec_Z'].append(h.history['loss'])
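A note on the target arrays used throughout this example: the np.ones(...) tensor shaped like D_Mask's output is the "real" label map the generator tries to make the mask discriminator produce, while np.zeros(batch_size) is a dummy target for the Enc_Modality output, whose KL divergence is computed inside the model and ignores its target. A standalone sketch, with an assumed patch-discriminator output shape:

import numpy as np

batch_size = 8
d_mask_output_shape = (16, 16, 1)  # assumption: a PatchGAN-style discriminator output
adv_target = np.ones((batch_size,) + d_mask_output_shape)  # generator wants D to answer "real"
kl_dummy = np.zeros(batch_size)    # placeholder; the KL loss does not use the target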
Example 6
    def plot_latent_representation(self, lb_images, ul_images, epoch):
        """
        Plot a 4-row image, where the first column shows the input image and the following columns
        each of the 8 channels of the spatial latent representation.
        :param lb_images:   a list of 2 4-dim arrays of images + corresponding masks
        :param ul_images:   a list of 4-dim image arrays
        :param epoch:       the epoch number
        """

        # combine labelled and unlabelled images and randomly sample 4 examples
        images = lb_images[0]  # [el[0] for el in lb_images]
        current_select = epoch % images.shape[3]
        if len(ul_images) > 0:
            images = np.concatenate([images, ul_images], axis=0)
            x = np.concatenate([images[0:2], ul_images[0:2]], axis=0)
        else:
            x = utils.data_utils.sample(images, nb_samples=4, seed=self.conf.seed)

        # plot S
        s = self.enc_anatomy.predict(x)

        rows = [np.concatenate([x[i, :, :, current_select]] + [s[i, :, :, s_chn] for s_chn in range(s.shape[-1])], axis=1)
                for i in range(x.shape[0])]
        im_plot = np.concatenate(rows, axis=0)
        imsave(self.lr_folder + '/s_lr_epoch_%d.png' % epoch, im_plot)
        # harric modified

        plt.figure()
        plt.imshow(im_plot, cmap='gray')
        plt.xticks([])
        plt.yticks([])
        plt.close()

        if self.conf.rounding == 'decoder':
            s = self.round_model.predict(x)
            rows = [np.concatenate([x[i, :, :, 0]] + [s[i, :, :, s_chn] for s_chn in range(s.shape[-1])], axis=1)
                   for i in range(x.shape[0])]
            im_plot = np.concatenate(rows, axis=0)
            imsave(self.lr_folder + '/srnd_lr_epoch_%d.png' % epoch, im_plot)
            # harric modified

            plt.figure()
            plt.imshow(im_plot, cmap='gray')
            plt.xticks([])
            plt.yticks([])
            plt.close()

        # plot Z
        enc_modality_inputs = [self.enc_anatomy.predict(images), images]
        z, _ = self.enc_modality.predict(enc_modality_inputs)
        gaussian = NormalDistribution()
        real_z = gaussian.sample(z.shape)

        fig, axes = plt.subplots(nrows=z.shape[1], ncols=2, sharex=True, sharey=True, figsize=(10, 8))
        axes[0, 0].set_title('Predicted Z')
        axes[0, 1].set_title('Real Z')
        for i in range(len(axes)):
            # 'normed' was removed in Matplotlib 3.1; 'density' is the current argument
            axes[i, 0].hist(z[:, i], density=True, bins=11, range=(-3, 3))
            axes[i, 1].hist(real_z[:, i], density=True, bins=11, range=(-3, 3))
        axes[0, 0].plot(0, 0)

        plt.savefig(self.lr_folder + '/z_lr_epoch_%d.png' % epoch)
        plt.close()

        means = self.z_mean.predict(enc_modality_inputs)
        variances = self.z_var.predict(enc_modality_inputs)
        means = np.var(means, axis=0)  # per-dimension spread of the posterior means across the batch
        variances = np.mean(np.exp(variances), axis=0)  # average posterior variance (z_var predicts log-variance)
        with open(self.lr_folder + '/z_means.csv', 'a+') as f:
            f.write(', '.join(str(means[i]) for i in range(means.shape[0])) + '\n')
        with open(self.lr_folder + '/z_vars.csv', 'a+') as f:
            f.write(', '.join(str(variances[i]) for i in range(variances.shape[0])) + '\n')
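The two CSV lines are a cheap posterior-collapse check: a z dimension whose means barely vary across the batch while its average variance approaches 1 has collapsed to the prior. The same computation in isolation, with dummy values standing in for the encoder outputs:

import numpy as np

z_mean_batch = np.random.randn(32, 8)           # stand-in for z_mean.predict(...), shape (batch, num_z)
z_log_var_batch = 0.1 * np.random.randn(32, 8)  # stand-in for z_var.predict(...)

spread_of_means = np.var(z_mean_batch, axis=0)                # ~0 for a collapsed dimension
avg_posterior_var = np.mean(np.exp(z_log_var_batch), axis=0)  # ~1 for a collapsed dimension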
Example 7
    def train_batch_generators_labeled_pathology(self, epoch_loss, lr_callback):
        """
        Train generator for labelled pathology networks.
        :param epoch_loss:   Dictionary of losses for the epoch
        :param lr_callback:  learning-rate callback passed to every fit call
        """
        if self.gen_labelled is not None:
            x, anato_m, patho_m, scanner = next(self.gen_labelled)
            pseudo_health_m = np.zeros(shape=patho_m.shape, dtype=anato_m.dtype)

            batch_size = x.shape[0]  # may differ from conf.batch_size for the last batch

            # generate pathology labels: 1 if the mask contains any pathology pixels
            actual_classification_label_batch = np.array(
                np.array(np.sum(patho_m, axis=(1, 2, 3)), dtype=bool), dtype=np.uint8)
            one_hot_actual_target = np.eye(2)[actual_classification_label_batch]

            # train for real pathology
            h_rp = self.model.G_trainer_lp_rp.fit(
                [x, patho_m],
                [np.zeros(batch_size),
                 np.zeros(x.shape[0], dtype=x.dtype),
                 np.ones((batch_size,) + self.model.D_Reconstruction.output_shape[0][1:])],
                epochs=1, verbose=0, callbacks=[lr_callback])

            s = self.model.Enc_Anatomy.predict(x)
            sample_z1 = NormalDistribution().sample((batch_size, self.conf.num_z))
            h_rp_z = self.model.z_reconstructor.fit(
                [s, patho_m, sample_z1], sample_z1, epochs=1, verbose=0,
                callbacks=[SingleWeights_Callback(
                    self.conf.w_rec_Z * self.conf.real_pathology_weight_rate + eps,
                    model=self.model.z_reconstructor)])

            # train for predicted pathology and predicted anatomy
            h_pppa = self.model.G_trainer_lp_pppa.fit(
                [x, patho_m], [patho_m, patho_m],
                epochs=1, verbose=0, callbacks=[lr_callback])
            h_pppa_reconst = self.model.G_trainer_lp_pppa_reconst.fit(
                [x, patho_m], [np.zeros(x.shape[0], dtype=x.dtype)],
                epochs=1, verbose=0, callbacks=[lr_callback])

            # train for predicted pathology and real anatomy
            h_ppra = self.model.G_trainer_lp_ppra.fit(
                [x, anato_m, patho_m], [patho_m, patho_m],
                epochs=1, verbose=0, callbacks=[lr_callback])
            h_ppra_reconst = self.model.G_trainer_lp_ppra_reconst.fit(
                [x, anato_m, patho_m], [np.zeros(x.shape[0], dtype=x.dtype)],
                epochs=1, verbose=0, callbacks=[lr_callback])

            # train for ratio-based triplet loss
            h_triplet = self.model.G_trainer_lp_triplet.fit(
                [x, patho_m, pseudo_health_m],
                [np.zeros(batch_size),
                 np.ones((batch_size,) + self.model.D_Reconstruction.output_shape[0][1:])],
                epochs=1, verbose=0, callbacks=[lr_callback])

            epoch_loss['Triplet_RP'].append(h_triplet.history['Triplet_RP_loss'][-1])

            epoch_loss['SegDice_Patho_PPPA'].append(h_pppa.history['Dice_Patho_loss'][-1])
            epoch_loss['SegCrossEntropy_Patho_PPPA'].append(h_pppa.history['CrossEntropy_Patho_loss'][-1])
            epoch_loss['SegDice_Patho_PPRA'].append(h_ppra.history['Dice_Patho_loss'][-1])
            epoch_loss['SegCrossEntropy_Patho_PPRA'].append(h_ppra.history['CrossEntropy_Patho_loss'][-1])

            epoch_loss['Adv_Reconstruction_Generator_RP'].append([h_rp.history['D_Reconstruction_loss'][-1]])
            epoch_loss['KL_ActualPathology_RP'].append([h_rp.history['Enc_Modality_loss'][-1]])

            epoch_loss['Reconstruct_X_RP'].append(h_rp.history['Reconstructor_RP_loss'][-1])
            epoch_loss['Reconstruct_X_PPPA_LP'].append(h_pppa_reconst.history['loss'][-1])
            epoch_loss['Reconstruct_X_PPRA_LP'].append(h_ppra_reconst.history['loss'][-1])

            # divide by the applied weight to recover the unweighted Z reconstruction loss
            epoch_loss['Reconstruct_Z_RP'].append([h_rp_z.history['loss'][-1] /
                                                   (self.conf.w_rec_Z * self.conf.real_pathology_weight_rate + eps)])

            return h_pppa.history['loss'][-1] + h_ppra.history['loss'][-1] \
                   + h_pppa_reconst.history['loss'][-1] + h_ppra_reconst.history['loss'][-1] \
                   + h_rp.history['loss'][-1] \
                   + h_triplet.history['loss'][-1] + h_rp_z.history['loss'][-1]
Example 8
    def train_batch_generators_unlabeld_pathology(self, epoch_loss, lr_callback):
        """
        Train generator for unlabelled pathology networks.
        :param epoch_loss:   Dictionary of losses for the epoch
        :param lr_callback:  learning-rate callback passed to every fit call
        """
        if self.gen_unlabelled is not None:
            x, anato_m, _, scanner = next(self.gen_unlabelled)
            pseudo_health_m = np.zeros(shape=anato_m.shape[:-1] + (self.conf.num_pathology_masks,),
                                       dtype=anato_m.dtype)
            batch_size = x.shape[0]

            h_pppa = self.model.G_trainer_up_pppa.fit(
                [x],
                [np.zeros(batch_size),
                 np.ones((batch_size,) + self.model.D_Reconstruction.output_shape[0][1:])],
                epochs=1, verbose=0, callbacks=[lr_callback])
            h_pppa_reconst = self.model.G_trainer_up_pppa_reconst.fit(
                [x], x, epochs=1, verbose=0, callbacks=[lr_callback])

            h_ppra = self.model.G_trainer_up_ppra.fit(
                [x, anato_m],
                [np.zeros(batch_size),
                 np.ones((batch_size,) + self.model.D_Reconstruction.output_shape[0][1:])],
                epochs=1, verbose=0, callbacks=[lr_callback])
            h_ppra_reconst = self.model.G_trainer_up_ppra_reconst.fit(
                [x, anato_m], [x],
                epochs=1, verbose=0, callbacks=[lr_callback])

            h_anatomy = self.model.G_trainer_up_anatomy.fit(
                [x], [anato_m, anato_m],
                epochs=1, verbose=0, callbacks=[lr_callback])

            h_triplet = self.model.G_trainer_up_triplet.fit(
                [x, pseudo_health_m, anato_m],
                [np.zeros(batch_size),
                 np.zeros(batch_size),
                 np.ones((batch_size,) + self.model.D_Reconstruction.output_shape[0][1:])],
                epochs=1, verbose=0, callbacks=[lr_callback])

            epoch_loss['Triplet_PPPA'].append(h_triplet.history['Triplet_PPPA_loss'][-1])
            epoch_loss['Triplet_PPRA'].append(h_triplet.history['Triplet_PPRA_loss'][-1])

            epoch_loss['SegDice_Anato'].append(h_anatomy.history['Dice_Anato_loss'][-1])
            epoch_loss['SegCrossEntropy_Anato'].append(h_anatomy.history['CrossEntropy_Anato_loss'][-1])

            epoch_loss['Adv_Reconstruction_Generator_PPPA'].append([h_pppa.history['D_Reconstruction_loss'][-1]])
            epoch_loss['KL_ActualPathology_PPPA'].append([h_pppa.history['Enc_Modality_loss'][-1]])
            epoch_loss['Reconstruct_X_PPPA_UP'].append(h_pppa_reconst.history['loss'][-1])
            epoch_loss['Adv_Reconstruction_Generator_PPRA'].append([h_ppra.history['D_Reconstruction_loss'][-1]])
            epoch_loss['KL_ActualPathology_PPRA'].append([h_ppra.history['Enc_Modality_loss'][-1]])
            epoch_loss['Reconstruct_X_PPRA_UP'].append(h_ppra_reconst.history['loss'][-1])

            s = self.model.Enc_Anatomy.predict(x)
            predicted_anatomy = self.model.Segmentor.predict(s)
            predicted_pathology_from_predicted_anatomy \
                = self.model.Enc_Pathology.predict(np.concatenate([x, predicted_anatomy], axis=-1))
            # the background channel is whatever the anatomy masks do not cover
            m_anato_background = np.ones(shape=anato_m.shape[:-1] + (1,))
            for ii in range(anato_m.shape[-1]):
                m_anato_background = m_anato_background - anato_m[:, :, :, ii:ii + 1]
            m_anato_with_background = np.concatenate([anato_m, m_anato_background], axis=-1)
            predicted_pathology_from_real_anatomy = self.model.Enc_Pathology.predict(
                np.concatenate([x, m_anato_with_background], axis=-1))
            sample_z2 = NormalDistribution().sample((batch_size, self.conf.num_z))
            sample_z3 = NormalDistribution().sample((batch_size, self.conf.num_z))

            h_pppa_z = self.model.z_reconstructor.fit(
                [s, predicted_pathology_from_predicted_anatomy[:, :, :, :-1], sample_z2],
                sample_z2, epochs=1, verbose=0,
                callbacks=[SingleWeights_Callback(
                    self.conf.w_rec_Z * self.conf.pred_pathology_weight_rate
                    * self.conf.pred_anatomy_weight_rate + eps,
                    model=self.model.z_reconstructor), lr_callback])

            h_ppra_z = self.model.z_reconstructor.fit(
                [s, predicted_pathology_from_real_anatomy[:, :, :, :-1], sample_z3],
                sample_z3, epochs=1, verbose=0,
                callbacks=[SingleWeights_Callback(
                    self.conf.w_rec_Z * self.conf.pred_pathology_weight_rate
                    * self.conf.real_anatomy_weight_rate + eps,
                    model=self.model.z_reconstructor), lr_callback])

            # again divide by the applied weights to log unweighted Z reconstruction losses
            epoch_loss['Reconstruct_Z_PPPA'].append(
                [h_pppa_z.history['loss'][-1] / (self.conf.w_rec_Z
                                                 * self.conf.pred_pathology_weight_rate
                                                 * self.conf.pred_anatomy_weight_rate + eps)])
            epoch_loss['Reconstruct_Z_PPRA'].append(
                [h_ppra_z.history['loss'][-1] / (self.conf.w_rec_Z
                                                 * self.conf.pred_pathology_weight_rate
                                                 * self.conf.real_anatomy_weight_rate + eps)])

            return h_anatomy.history['loss'][-1] \
                   + h_pppa_reconst.history['loss'][-1] + h_ppra_reconst.history['loss'][-1] \
                   + h_pppa.history['loss'][-1] + h_ppra.history['loss'][-1] \
                   + h_triplet.history['loss'][-1] \
                   + h_pppa_z.history['loss'][-1] + h_ppra_z.history['loss'][-1]
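The channel loop that builds the background mask can also be written as a single broadcast subtraction; a small equivalence sketch with dummy masks (shapes are illustrative only):

import numpy as np

anato_m = np.random.randint(0, 2, size=(8, 128, 128, 3)).astype(np.float32)  # dummy (B, H, W, C) masks
m_anato_background = 1.0 - anato_m.sum(axis=-1, keepdims=True)  # same result as subtracting channel by channel
m_anato_with_background = np.concatenate([anato_m, m_anato_background], axis=-1)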
Example 9
    def train_batch_generators(self, epoch_loss):
        """
        Train generator/segmentation networks.
        :param epoch_loss:  Dictionary of losses for the epoch
        """
        num_mod = len(self.model.modalities)

        if self.conf.l_mix > 0:
            x1, x2, m1, m2 = next(self.gen_labelled)
            [x1, x2, m1, m2] = self.align_batches([x1, x2, m1, m2])
            batch_size = x1.shape[0]  # may differ from conf.batch_size for the last batch
            dm_shape = (batch_size,) + self.model.D_Mask.output_shape[1:]
            ones_m = np.ones(shape=dm_shape)

            # Train labelled path (supervised_model)
            all_outputs = [m1, m2, m2, m2, m1, m1] + \
                          [ones_m for _ in range(num_mod * 3)] + \
                          [x1, x2, x2, x2, x1, x1] + \
                          [np.zeros(batch_size) for _ in range(num_mod * 3)]
            h = self.model.supervised_trainer.fit([x1, x2], all_outputs, epochs=1, verbose=0)
            epoch_loss['supervised_Mask'].append(np.mean(h.history['Segmentor_loss']))
            epoch_loss['adv_M'].append(np.mean(h.history['D_Mask_loss']))
            epoch_loss['rec_X'].append(np.mean(h.history['Decoder_loss']))
            epoch_loss['KL'].append(np.mean(h.history['Enc_Modality_loss']))

            # Train Z Regressor
            norm = NormalDistribution()
            s_list = [self.model.Encoders_Anatomy[i].predict(x) for i, x in enumerate([x1, x2])]
            s1_def, s1_fused = self.model.Anatomy_Fuser.predict(s_list)
            s2_def, s2_fused = self.model.Anatomy_Fuser.predict(list(reversed(s_list)))
            s_list += [s1_def, s1_fused]
            s_list += [s2_def, s2_fused]
            z_list = [norm.sample((batch_size, self.conf.num_z)) for _ in range(num_mod * 3)]
            h = self.model.Z_Regressor.fit(s_list + z_list, z_list, epochs=1, verbose=0)
            epoch_loss['rec_Z'].append(np.mean(h.history['loss']))

        # Train unlabelled path
        if self.conf.l_mix < 1:
            x1, x2, m1 = next(self.gen_unlabelled)
            [x1, x2, m1] = self.align_batches([x1, x2, m1])
            batch_size = x1.shape[0]  # may differ from conf.batch_size for the last batch
            dm_shape = (batch_size,) + self.model.D_Mask.output_shape[1:]
            ones_m = np.ones(shape=dm_shape)

            # Train unlabelled path (G_model)
            all_outputs = [m1, m1, m1] + \
                          [ones_m for _ in range(num_mod * 3)] + \
                          [x1, x2, x2, x2, x1, x1] + \
                          [np.zeros(batch_size) for _ in range(num_mod * 3)]
            h = self.model.unsupervised_trainer.fit([x1, x2], all_outputs, epochs=1, verbose=0)
            epoch_loss['supervised_Mask'].append(np.mean(h.history['Segmentor_loss']))
            epoch_loss['adv_M'].append(np.mean(h.history['D_Mask_loss']))
            epoch_loss['rec_X'].append(np.mean(h.history['Decoder_loss']))
            epoch_loss['KL'].append(np.mean(h.history['Enc_Modality_loss']))

            # Train Z Regressor
            norm = NormalDistribution()
            s_list = [self.model.Encoders_Anatomy[i].predict(x) for i, x in enumerate([x1, x2])]
            s1_def, s1_fused = self.model.Anatomy_Fuser.predict(s_list)
            s2_def, s2_fused = self.model.Anatomy_Fuser.predict(list(reversed(s_list)))
            s_list += [s1_def, s1_fused]
            s_list += [s2_def, s2_fused]
            z_list = [norm.sample((batch_size, self.conf.num_z)) for _ in range(num_mod * 3)]
            h = self.model.Z_Regressor.fit(s_list + z_list, z_list, epochs=1, verbose=0)
            epoch_loss['rec_Z'].append(np.mean(h.history['loss']))