Example #1
 def __init__(self, args):
     self.args = args
     torch.manual_seed(self.args.seed)
     np.random.seed(self.args.seed)
     print('{} detection...'.format(args.dataset))
     white_noise = dp.DatasetReader(white_noise=self.args.dataset,
                                    data_path=data_path,
                                    data_source=args.data_source,
                                    len_seg=self.args.len_seg)
     _, self.testset = white_noise(args.net_name)
     self.spots = np.load('{}/spots.npy'.format(info_path))
     self.AE = AutoEncoder(args)
     self.latent = np.load('{}/features/{}.npy'.format(
         save_path, self.file_name()))
Example #2
 def __init__(self, args):
     self.args = args
     white_noise = DatasetReader(white_noise='W-1',
                                 data_path=data_path,
                                 data_source=args.data_source,
                                 len_seg=self.args.len_seg)
     self.dataset, _ = white_noise(args.net_name)
     self.spots = np.load('{}/spots.npy'.format(info_path))
     self.AE = AutoEncoder(args)
     self.font = {
         'family': 'Arial',
         'style': 'normal',
         'weight': 'bold',
         'size': 10,
         'color': 'k',
     }
Example #3
def train_model(x_data, y_data):
    model = None

    if args.name == "VAE":
        model = vae(args, device).to(device)
    elif args.name == "AutoEncoder":
        model = AutoEncoder(args, device).to(device)
    else:
        # Fail fast instead of letting model stay None and crash in learn()
        raise ValueError("Unknown model name: {}".format(args.name))

    for epoch in range(args.nb_epochs):
        x_train, y_train = Utils.SuffleData(x_data, y_data, args.batch_size)

        loss = model.learn(x_train)

        if epoch % args.log_interval == 0:
            print('Epoch {:4d}/{} loss: {:.6f} '.format(epoch, args.nb_epochs, loss))

    return model
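
For context, a minimal, hypothetical driver for train_model above: the argument fields (name, nb_epochs, batch_size, log_interval) mirror the ones the function reads, while the data shapes and the Utils.SuffleData contract are assumptions about the surrounding script, not a documented API.

import argparse

import numpy as np

parser = argparse.ArgumentParser()
parser.add_argument('--name', default='AutoEncoder')        # or 'VAE'
parser.add_argument('--nb_epochs', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--log_interval', type=int, default=10)
args = parser.parse_args([])

# Dummy data; Utils.SuffleData(x, y, batch_size) is assumed to return one
# shuffled training batch per call.
x_data = np.random.randn(1024, 32).astype('float32')
y_data = np.random.randn(1024, 32).astype('float32')

model = train_model(x_data, y_data)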
Example #4
 def __init__(self, args):
     self.args = args
     torch.manual_seed(self.args.seed)
     np.random.seed(self.args.seed)
     print('> Training arguments:')
     for arg in vars(args):
         print('>>> {}: {}'.format(arg, getattr(args, arg)))
     white_noise = dp.DatasetReader(white_noise=self.args.dataset,
                                    data_path=data_path,
                                    data_source=args.data_source,
                                    len_seg=self.args.len_seg)
     dataset, _ = white_noise(args.net_name)
     self.data_loader = DataLoader(dataset=dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False)
     self.spots = np.load('{}/spots.npy'.format(info_path))
     self.AE = AutoEncoder(args).to(device)  # AutoEncoder
     self.AE.apply(self.weights_init)
     self.criterion = nn.MSELoss()
     self.vis = visdom.Visdom(
         env='{}'.format(self.file_name()),
         log_to_filename='{}/visualization/{}.log'.format(
             save_path, self.file_name()))
     plt.figure(figsize=(15, 15))
Example #5
def baseline_ae_model(region_tensors, is_training, encoded_dims):
    squared_euclidean = squared_x = None
    for name in region_tensors:
        tensors = region_tensors[name]['tensors']
        filters = region_tensors[name]['filters']
        reuse = region_tensors[name]['reuse']
        scope = region_tensors[name]['scope']
        ae = AutoEncoder(tensors, encoded_dims=encoded_dims, filters=filters, is_training=is_training, reuse=reuse, name='compressor_{}'.format(scope))
        sq_x, sq_euclidean, cosine_similarity = reconstruction_distances(ae.input_tensor, ae.reconstruction)
        with tf.name_scope('latent_variables'):
            if squared_euclidean is None:
                squared_euclidean = sq_euclidean
                squared_x = sq_x
            else:
                squared_euclidean = squared_euclidean + sq_euclidean
                squared_x = squared_x + sq_x

    # Return once every region has been accumulated (the original returned
    # inside the loop, i.e. after the first region only).
    return tf.sqrt(squared_euclidean) / tf.sqrt(squared_x), ae.flatten_encoded, tf.reduce_mean(squared_euclidean)
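
For reference, a sketch of the region_tensors layout baseline_ae_model expects, inferred from the dictionary lookups above; the region name, placeholder shape and filter counts are illustrative assumptions (TF1-style graph code, matching the snippet).

import tensorflow as tf

region_tensors = {
    'region_a': {
        'tensors': tf.placeholder(tf.float32, [None, 32, 32, 1]),  # shape assumed
        'filters': [16, 32],                                       # assumed
        'reuse': False,
        'scope': 'a',
    },
}
is_training = tf.placeholder(tf.bool, name='is_training')
relative_dist, encoded, mean_sq_euclidean = baseline_ae_model(
    region_tensors, is_training, encoded_dims=2)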
Example #6
class Reconstruction:
    def __init__(self, args):
        self.args = args
        white_noise = DatasetReader(white_noise='W-1',
                                    data_path=data_path,
                                    data_source=args.data_source,
                                    len_seg=self.args.len_seg)
        self.dataset, _ = white_noise(args.net_name)
        self.spots = np.load('{}/spots.npy'.format(info_path))
        self.AE = AutoEncoder(args)
        self.font = {
            'family': 'Arial',
            'style': 'normal',
            'weight': 'bold',
            'size': 10,
            'color': 'k',
        }

    def file_name(self):
        if self.args.net_name == 'MLP':
            return '{}_{}_{}_{}_{}_{}'.format(self.args.model_name,
                                              self.args.net_name,
                                              self.args.len_seg,
                                              self.args.optimizer,
                                              self.args.learning_rate,
                                              self.args.num_epoch)
        else:
            return '{}_{}_{}_{}_{}_{}_{}'.format(
                self.args.model_name, self.args.net_name, self.args.len_seg,
                self.args.optimizer, self.args.learning_rate,
                self.args.num_epoch, self.args.num_hidden_map)

    def load_model(self):
        path = '{}/models/{}/{}.model'.format(save_path, self.args.model_name,
                                              self.file_name())
        self.AE.load_state_dict(
            torch.load(path,
                       map_location=torch.device(device)))  # Load AutoEncoder

    def show_reconstruct(self):
        self.load_model()
        fig, axs = plt.subplots(nrows=int(len(self.spots) / 2),
                                ncols=2,
                                figsize=(10, 10),
                                sharey=True)
        seg_idx = self.args.seg_idx
        with torch.no_grad():
            num_seg = int(self.dataset.shape[0] / len(self.spots))
            spots_l1, spots_l2 = np.hsplit(self.spots, 2)
            for i, (spot_l1, spot_l2) in enumerate(zip(spots_l1, spots_l2)):
                # L1 sensors
                x = self.dataset[i * num_seg + seg_idx]
                x = x.to(device)
                axs[i][0].set_title('{}-{}'.format(spot_l1, seg_idx),
                                    fontdict=self.font)
                axs[i][0].plot(x.view(-1).detach().cpu().numpy(),
                               c='b',
                               lw=1,
                               label='Original')
                if self.args.net_name == 'Conv2D':
                    x = x.unsqueeze(0).unsqueeze(2)
                x_hat, _, _ = self.AE(x)
                axs[i][0].plot(x_hat.view(-1).detach().cpu().numpy(),
                               ls='--',
                               lw=1,
                               c='r',
                               label='Reconstructed')
                axs[i][0].axvline(x=127, ls='--', c='k')
                axs[i][0].axvline(x=255, ls='--', c='k')
                axs[i][0].set_xticks(
                    np.linspace(self.args.dim_input / 6,
                                self.args.dim_input - self.args.dim_input / 6,
                                3))
                axs[i][0].set_xticklabels(['NS', 'EW', 'V'])
                axs[i][0].set_ylabel('Amplitude', fontdict=self.font)
                axs[i][0].legend(loc='upper center', prop={'size': 8})
                # L2 sensors
                x = self.dataset[(i + 5) * num_seg + seg_idx]
                x = x.to(device)
                axs[i][1].plot(x.view(-1).detach().cpu().numpy(),
                               c='b',
                               lw=1,
                               label='Original')
                axs[i][1].set_title('{}-{}'.format(spot_l2, seg_idx),
                                    fontdict=self.font)
                if self.args.net_name == 'Conv2D':
                    x = x.unsqueeze(0).unsqueeze(2)
                x_hat, _, _ = self.AE(x)
                axs[i][1].plot(x_hat.view(-1).detach().cpu().numpy(),
                               ls='--',
                               lw=1,
                               c='r',
                               label='Reconstructed')
                axs[i][1].axvline(x=127, ls='--', c='k')
                axs[i][1].axvline(x=255, ls='--', c='k')
                axs[i][1].set_xticks(
                    np.linspace(self.args.dim_input / 6,
                                self.args.dim_input - self.args.dim_input / 6,
                                3))
                axs[i][1].set_xticklabels(['NS', 'EW', 'V'])
                axs[i][1].legend(loc='upper center', prop={'size': 8})
            plt.tight_layout()

    def show_latent_reconstruction(self):
        self.load_model()
        fig, axs = plt.subplots(nrows=int(len(self.spots) / 2),
                                ncols=2,
                                figsize=(10, 10))
        seg_idx = self.args.seg_idx
        with torch.no_grad():
            num_seg = int(self.dataset.shape[0] / len(self.spots))
            spots_l1, spots_l2 = np.hsplit(self.spots, 2)
            for i, (spot_l1, spot_l2) in enumerate(zip(spots_l1, spots_l2)):
                # L1 sensors
                x = self.dataset[i * num_seg + seg_idx]
                x = x.to(device)
                axs[i][0].set_title('{}-{}'.format(spot_l1, seg_idx))
                if self.args.net_name == 'Conv2D':
                    x = x.unsqueeze(0).unsqueeze(2)
                _, z, z_hat = self.AE(x)
                axs[i][0].plot(z.view(-1).detach().cpu().numpy(),
                               ls='--',
                               lw=1,
                               c='b',
                               label='Original')
                axs[i][0].plot(z_hat.view(-1).detach().cpu().numpy(),
                               ls='--',
                               lw=1,
                               c='r',
                               label='Reconstructed')
                axs[i][0].set_xticks(np.linspace(0, z.view(-1).size(0) - 1, 3))
                axs[i][0].set_xticklabels([1, '...', z.view(-1).size(0)])
                axs[i][0].legend(loc='upper center', prop={'size': 8})
                # L2 sensors
                x = self.dataset[(i + 5) * num_seg + seg_idx]
                x = x.to(device)
                axs[i][1].set_title('{}-{}'.format(spot_l2, seg_idx))
                if self.args.net_name == 'Conv2D':
                    x = x.unsqueeze(0).unsqueeze(2)
                _, z, z_hat = self.AE(x)
                axs[i][1].plot(z.view(-1).detach().cpu().numpy(),
                               ls='--',
                               lw=1,
                               c='b',
                               label='Original')
                axs[i][1].plot(z_hat.view(-1).detach().cpu().numpy(),
                               ls='--',
                               lw=1,
                               c='r',
                               label='Reconstructed')
                axs[i][1].set_xticks(np.linspace(0, z.view(-1).size(0) - 1, 3))
                axs[i][1].set_xticklabels([1, '...', z.view(-1).size(0)])
                axs[i][1].legend(loc='upper center', prop={'size': 8})
            plt.tight_layout()
Example #7
from tensorflow.python.keras.datasets import mnist
from models.AutoEncoder import AutoEncoder
import tensorflow as tf
import numpy as np

(x_train, _), (x_test, _) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = np.reshape(x_train, (len(x_train), 28, 28, 1))
x_test = np.reshape(x_test, (len(x_test), 28, 28, 1))

autoencoder = AutoEncoder()
autoencoder.build()
# autoencoder.train(x_train, x_test)
Example #8
class BaseExperiment:
    def __init__(self, args):
        self.args = args
        torch.manual_seed(self.args.seed)
        np.random.seed(self.args.seed)
        print('> Training arguments:')
        for arg in vars(args):
            print('>>> {}: {}'.format(arg, getattr(args, arg)))
        white_noise = dp.DatasetReader(white_noise=self.args.dataset,
                                       data_path=data_path,
                                       data_source=args.data_source,
                                       len_seg=self.args.len_seg)
        dataset, _ = white_noise(args.net_name)
        self.data_loader = DataLoader(dataset=dataset,
                                      batch_size=args.batch_size,
                                      shuffle=False)
        self.spots = np.load('{}/spots.npy'.format(info_path))
        self.AE = AutoEncoder(args).to(device)  # AutoEncoder
        self.AE.apply(self.weights_init)
        self.criterion = nn.MSELoss()
        self.vis = visdom.Visdom(
            env='{}'.format(self.file_name()),
            log_to_filename='{}/visualization/{}.log'.format(
                save_path, self.file_name()))
        plt.figure(figsize=(15, 15))

    def select_optimizer(self, model):
        if self.args.optimizer == 'Adam':
            optimizer = optim.Adam(
                filter(lambda p: p.requires_grad, model.parameters()),
                lr=self.args.learning_rate,
                betas=(0.5, 0.999),
            )
        elif self.args.optimizer == 'AdaBelief':
            optimizer = AdaBelief(model.parameters(),
                                  lr=self.args.learning_rate,
                                  betas=(0.5, 0.999))
        elif self.args.optimizer == 'RMS':
            optimizer = optim.RMSprop(filter(lambda p: p.requires_grad,
                                             model.parameters()),
                                      lr=self.args.learning_rate)
        elif self.args.optimizer == 'SGD':
            optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                         model.parameters()),
                                  lr=self.args.learning_rate,
                                  momentum=0.9)
        elif self.args.optimizer == 'Adagrad':
            optimizer = optim.Adagrad(filter(lambda p: p.requires_grad,
                                             model.parameters()),
                                      lr=self.args.learning_rate)
        elif self.args.optimizer == 'Adadelta':
            optimizer = optim.Adadelta(filter(lambda p: p.requires_grad,
                                              model.parameters()),
                                       lr=self.args.learning_rate)
        else:
            # Avoid returning an unbound name on an unrecognized option
            raise ValueError('Unknown optimizer: {}'.format(
                self.args.optimizer))
        return optimizer

    def weights_init(self, m):
        initializers = {
            'xavier_uniform_': nn.init.xavier_uniform_,
            'xavier_normal_': nn.init.xavier_normal_,
            'orthogonal_': nn.init.orthogonal_,
            'kaiming_normal_': nn.init.kaiming_normal_
        }
        initializer = initializers[self.args.initializer]
        if isinstance(m, nn.Linear):
            initializer(m.weight)
            m.bias.data.fill_(0)
        elif isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

    def file_name(self):
        if self.args.net_name == 'MLP':
            return '{}_{}_{}_{}_{}_{}'.format(self.args.model_name,
                                              self.args.net_name,
                                              self.args.len_seg,
                                              self.args.optimizer,
                                              self.args.learning_rate,
                                              self.args.num_epoch)
        else:
            return '{}_{}_{}_{}_{}_{}_{}'.format(
                self.args.model_name, self.args.net_name, self.args.len_seg,
                self.args.optimizer, self.args.learning_rate,
                self.args.num_epoch, self.args.num_hidden_map)

    def train(self):
        optimizer = self.select_optimizer(self.AE)
        best_loss = float('inf')  # so the first epoch always checkpoints
        best_epoch = 1
        lh = {}
        losses, mses_x, mses_z = [], [], []
        for epoch in range(self.args.num_epoch):
            t0 = time.time()
            if self.args.net_name == 'MLP':
                if self.args.model_name == 'VAE':
                    f = torch.zeros(len(self.data_loader.dataset),
                                    int(self.args.dim_feature / 2))
                else:
                    f = torch.zeros(len(self.data_loader.dataset),
                                    int(self.args.dim_feature))
            else:
                f = torch.zeros(len(self.data_loader.dataset),
                                self.args.num_hidden_map, 1, 8)
            idx = 0
            for _, sample_batched in enumerate(self.data_loader):
                batch_size = sample_batched.size(0)
                x = sample_batched.to(device)
                if self.args.net_name == 'Conv2D':
                    x = x.unsqueeze(2)
                if self.args.model_name == 'VAE':
                    x_hat, z, z_kld = self.AE(x)
                    loss = self.criterion(x_hat, x)
                    elbo = -loss - 1.0 * z_kld
                    loss = -elbo
                else:
                    x_hat, z, z_hat = self.AE(x)
                    mse_x = self.criterion(x_hat, x)
                    mse_z = self.criterion(z_hat, z)
                    loss = (self.args.beta * mse_x +
                            (1 - self.args.beta) * mse_z)
                f[idx:idx + batch_size] = z
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                idx += batch_size
            t1 = time.time()
            if self.args.model_name == 'VAE':
                print('\033[1;31mEpoch: {}\033[0m\t'
                      '\033[1;32mReconstruction loss: {:5f}\033[0m\t'
                      '\033[1;33mKL Divergence: {:5f}\033[0m\t'
                      '\033[1;35mTime cost: {:2f}s\033[0m'.format(
                          epoch + 1, loss.item(), z_kld, t1 - t0))
            else:
                print('\033[1;31mEpoch: {}\033[0m\t'
                      '\033[1;32mLoss: {:5f}\033[0m\t'
                      '\033[1;33mMSE: {:5f}\033[0m\t'
                      '\033[1;34mMSE_latent: {:5f}\033[0m\t'
                      '\033[1;35mTime cost: {:2f}s\033[0m'.format(
                          epoch + 1, loss.item(), mse_x.item(), mse_z.item(),
                          t1 - t0))
            if loss.item() < best_loss:
                best_loss = loss.item()
                best_epoch = epoch + 1
                f = f.detach().numpy()
                path = '{}/models/{}/{}.model'.format(save_path,
                                                      self.args.model_name,
                                                      self.file_name())
                torch.save(self.AE.state_dict(), path)
                np.save(
                    '{}/features/{}.npy'.format(save_path, self.file_name()),
                    f)
            losses.append(loss.item())
            if self.args.model_name != 'VAE':
                # mse_x / mse_z are only defined on the non-VAE branch
                mses_x.append(mse_x.item())
                mses_z.append(mse_z.item())
            self.show_loss(loss, epoch)
            self.show_reconstruction(epoch)
        plt.close()
        lh['Loss'] = losses
        lh['MSE'] = mses_x
        lh['MSE latent'] = mses_z
        lh['Min loss'] = best_loss
        lh['Best epoch'] = best_epoch
        lh = json.dumps(lh, indent=2)
        with open(
                '{}/learning history/{}.json'.format(save_path,
                                                     self.file_name()),
                'w') as f:
            f.write(lh)

    def show_loss(self, loss, epoch):
        self.vis.line(Y=np.array([loss.item()]),
                      X=np.array([epoch + 1]),
                      win='Train loss',
                      opts=dict(title='Train loss'),
                      update='append')

    def show_reconstruction(self, epoch, seg_idx=25):
        plt.clf()
        num_seg = int(self.data_loader.dataset.shape[0] / len(self.spots))
        spots_l1, spots_l2 = np.hsplit(self.spots, 2)
        for i, (spot_l1, spot_l2) in enumerate(zip(spots_l1, spots_l2)):
            # L1 sensors
            plt.subplot(int(len(self.spots) / 2), 2, 2 * i + 1)
            x = self.data_loader.dataset[i * num_seg + seg_idx]
            x = x.to(device)
            plt.plot(x.view(-1).detach().cpu().numpy(), label='original')
            plt.title('A-{}-{}'.format(spot_l1, seg_idx))
            if self.args.net_name == 'Conv2D':
                x = x.unsqueeze(0).unsqueeze(2)
            x_hat, _, _ = self.AE(x)
            plt.plot(x_hat.view(-1).detach().cpu().numpy(),
                     label='reconstruct')
            plt.axvline(x=127, ls='--', c='k')
            plt.axvline(x=255, ls='--', c='k')
            plt.legend(loc='upper center')
            # L2 sensors
            plt.subplot(int(len(self.spots) / 2), 2, 2 * (i + 1))
            x = self.data_loader.dataset[(i + 5) * num_seg + seg_idx]
            x = x.to(device)
            plt.plot(x.view(-1).detach().cpu().numpy(), label='original')
            plt.title('A-{}-{}'.format(spot_l2, seg_idx))
            if self.args.net_name == 'Conv2D':
                x = x.unsqueeze(0).unsqueeze(2)
            x_hat, _, _ = self.AE(x)
            plt.plot(x_hat.view(-1).detach().cpu().numpy(),
                     label='reconstruct')
            plt.axvline(x=127, ls='--', c='k')
            plt.axvline(x=255, ls='--', c='k')
            plt.legend(loc='upper center')
        plt.subplots_adjust(hspace=0.5)
        self.vis.matplot(plt,
                         win='Reconstruction',
                         opts=dict(title='Epoch: {}'.format(epoch + 1)))
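
In the notation of the code above, the non-VAE objective minimized by train is a convex combination of the input and latent reconstruction errors (beta = args.beta):

    loss = beta * MSE(x_hat, x) + (1 - beta) * MSE(z_hat, z)

so beta = 1 reduces to a plain reconstruction loss, and smaller beta puts more weight on reconstructing the latent code.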
Example #9
class DamageDetection:
    def __init__(self, args):
        self.args = args
        torch.manual_seed(self.args.seed)
        np.random.seed(self.args.seed)
        print('{} detection...'.format(args.dataset))
        white_noise = dp.DatasetReader(white_noise=self.args.dataset,
                                       data_path=data_path,
                                       data_source=args.data_source,
                                       len_seg=self.args.len_seg)
        _, self.testset = white_noise(args.net_name)
        self.spots = np.load('{}/spots.npy'.format(info_path))
        self.AE = AutoEncoder(args)
        self.latent = np.load('{}/features/{}.npy'.format(
            save_path, self.file_name()))

    def __call__(self, *args, **kwargs):
        self.test()

    def file_name(self):
        if self.args.net_name == 'MLP':
            return '{}_{}_{}_{}_{}_{}'.format(self.args.model_name,
                                              self.args.net_name,
                                              self.args.len_seg,
                                              self.args.optimizer,
                                              self.args.learning_rate,
                                              self.args.num_epoch)
        else:
            return '{}_{}_{}_{}_{}_{}_{}'.format(
                self.args.model_name, self.args.net_name, self.args.len_seg,
                self.args.optimizer, self.args.learning_rate,
                self.args.num_epoch, self.args.num_hidden_map)

    def damage_index(self, err):
        return 1 - np.exp(-self.args.alpha * err)

    def test(self):
        path = '{}/models/{}/{}.model'.format(save_path, self.args.model_name,
                                              self.file_name())
        self.AE.load_state_dict(
            torch.load(path,
                       map_location=torch.device(device)))  # Load AutoEncoder
        self.AE.eval()
        damage_indices = {}
        size_0 = self.testset.size(0)
        size_1 = self.testset.size(1)
        if self.args.net_name == 'MLP':
            if self.args.model_name == 'VAE':
                feats = torch.zeros(size_0 * size_1,
                                    int(self.args.dim_feature / 2))
            else:
                feats = torch.zeros(size_0 * size_1,
                                    int(self.args.dim_feature))
        else:
            feats = torch.zeros(size_0 * size_1, self.args.num_hidden_map * 8)
        idx = 0
        with torch.no_grad():
            for i, spot in enumerate(self.spots):
                damage_indices[spot] = {}
                x = self.testset[i]
                x_size = x.size(0)
                # z_w1 = self.latent[idx: idx + x_size]
                if self.args.net_name == 'Conv2D':
                    x = x.unsqueeze(2)
                x_hat, z, z_hat = self.AE(x)
                if self.args.net_name == 'Conv2D':  # Flatten
                    z = z.reshape(x_size, -1)
                    z_hat = z_hat.reshape(x_size, -1)
                feats[idx:idx + x_size] = z  # Latent
                loss_x = ((x - x_hat)**2).mean()  # Reconstruction loss
                loss_z = ((z - z_hat)**2).mean()  # Latent loss
                loss = self.args.beta * loss_x.item() + (
                    1 - self.args.beta) * loss_z.item()  # Overall loss
                damage_index = self.damage_index(loss)
                damage_indices[spot]['Reconstruction loss'] = loss_x.item()
                damage_indices[spot]['Latent loss'] = loss_z.item()
                damage_indices[spot]['Damage index'] = damage_index
                print('\033[1;32m[{}]\033[0m\t'
                      '\033[1;31mReconstruction loss: {:5f}\033[0m\t'
                      '\033[1;33mLatent loss: {:5f}\033[0m\t'
                      '\033[1;34mLoss: {:5f}\033[0m\t'
                      '\033[1;35mDamage index: {:5f}\033[0m'.format(
                          spot, loss_x.item(), loss_z.item(), loss,
                          damage_index))
                idx += x_size
        damage_indices = json.dumps(damage_indices, indent=2)
        with open(
                '{}/damage index/{}_{}.json'.format(save_path,
                                                    self.args.dataset,
                                                    self.file_name()),
                'w') as f:
            f.write(damage_indices)
        np.save(
            '{}/features/test/{}_{}.npy'.format(save_path, self.args.dataset,
                                                self.file_name()), feats)
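
The damage index reported above maps the combined loss into [0, 1): with alpha = args.alpha and err the beta-weighted sum of reconstruction and latent losses computed in test,

    DI = 1 - exp(-alpha * err)

so an undamaged spot (err near 0) yields DI near 0, while larger errors saturate toward 1.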
Example #10
from Preprocessor import Preprocessor
from train.SDNetTrainer import SDNetTrainer
from datasets.STL10 import STL10
from models.AutoEncoder import AutoEncoder
from models.SpotNet import SNet

target_shape = [96, 96, 3]
ae = AutoEncoder(num_layers=4,
                 batch_size=128,
                 target_shape=target_shape,
                 tag='default')
model = SNet(ae,
             batch_size=128,
             target_shape=target_shape,
             disc_pad='SAME',
             tag='default')
data = STL10()
preprocessor = Preprocessor(target_shape=target_shape, augment_color=True)
trainer = SDNetTrainer(model=model,
                       dataset=data,
                       pre_processor=preprocessor,
                       num_epochs=500,
                       init_lr=0.0003,
                       lr_policy='linear',
                       num_gpus=2)
trainer.train_model(None)
Example #11
def dagmm(region_tensors,
          is_training,
          encoded_dims=2,
          mixtures=3,
          lambda_1=0.1,
          lambda_2=0.005,
          use_cosine_similarity=False,
          latent_dims=2):
    """
    :param region_tensors: restore the related tensors
    :param is_training: a tensorflow placeholder to indicate whether it is in the training phase or not
    :param encoded_dims:
    :param mixtures:
    :param lambda_1:
    :param lambda_2:
    :param use_cosine_similarity:
    :param latent_dims: reduce the dimension of encoded vector to a smaller one
    :return:
    """

    squared_x = squared_euclidean = z = None
    for name in region_tensors:
        tensors = region_tensors[name]['tensors']
        filters = region_tensors[name]['filters']
        reuse = region_tensors[name]['reuse']
        scope = region_tensors[name]['scope']
        ae = AutoEncoder(tensors,
                         encoded_dims=encoded_dims,
                         filters=filters,
                         is_training=is_training,
                         reuse=reuse,
                         name='compressor_{}'.format(scope))
        reduced_latent = base_dense_layer(ae.flatten_encoded,
                                          latent_dims,
                                          'reducer_{}'.format(name),
                                          is_training=is_training,
                                          bn=False,
                                          activation_fn=None)
        sq_x, sq_euclidean, cosine_similarity = reconstruction_distances(
            ae.input_tensor, ae.reconstruction)
        with tf.name_scope('latent_variables'):
            if use_cosine_similarity:
                relative_euclidean = tf.sqrt(sq_euclidean) / tf.sqrt(sq_x)
                relative_euclidean = tf.reshape(relative_euclidean, [-1, 1])
                cosine_similarity = tf.reshape(cosine_similarity, [-1, 1])
                distances = tf.concat([relative_euclidean, cosine_similarity],
                                      axis=1)
            else:
                distances = tf.sqrt(sq_euclidean) / tf.sqrt(sq_x)
                distances = tf.reshape(distances, [-1, 1])

            if squared_x is None:
                squared_x = sq_x
                squared_euclidean = sq_euclidean
                z = tf.concat([reduced_latent, distances], axis=1)
            else:
                squared_x = squared_x + sq_x
                squared_euclidean = squared_euclidean + sq_euclidean
                z = tf.concat([z, reduced_latent, distances], axis=1)
    with tf.name_scope('n_count'):
        n_count = tf.shape(z)[0]
        n_count = tf.cast(n_count, tf.float32)

    estimator = Estimator(mixtures, z, is_training=is_training)
    gammas = estimator.output_tensor

    with tf.variable_scope('gmm_parameters'):
        phis = tf.get_variable('phis',
                               shape=[mixtures],
                               initializer=tf.ones_initializer(),
                               dtype=tf.float32,
                               trainable=False)
        mus = tf.get_variable('mus',
                              shape=[mixtures, z.get_shape()[1]],
                              initializer=tf.ones_initializer(),
                              dtype=tf.float32,
                              trainable=False)

        init_sigmas = 0.5 * np.expand_dims(np.identity(z.get_shape()[1]),
                                           axis=0)
        init_sigmas = np.tile(init_sigmas, [mixtures, 1, 1])
        init_sigmas = tf.constant_initializer(init_sigmas)
        sigmas = tf.get_variable(
            'sigmas',
            shape=[mixtures, z.get_shape()[1],
                   z.get_shape()[1]],
            initializer=init_sigmas,
            dtype=tf.float32,
            trainable=False)

        sums = tf.reduce_sum(gammas, axis=0)
        sums_exp_dims = tf.expand_dims(sums, axis=-1)

        phis_ = sums / n_count
        mus_ = tf.matmul(gammas, z, transpose_a=True) / sums_exp_dims

        def assign_training_phis_mus():
            with tf.control_dependencies(
                [phis.assign(phis_), mus.assign(mus_)]):
                return [tf.identity(phis), tf.identity(mus)]

        phis, mus = tf.cond(is_training, assign_training_phis_mus,
                            lambda: [phis, mus])

        phis_exp_dims = tf.expand_dims(phis, axis=0)
        phis_exp_dims = tf.expand_dims(phis_exp_dims, axis=-1)
        phis_exp_dims = tf.expand_dims(phis_exp_dims, axis=-1)

        zs_exp_dims = tf.expand_dims(z, 1)
        zs_exp_dims = tf.expand_dims(zs_exp_dims, -1)
        mus_exp_dims = tf.expand_dims(mus, 0)
        mus_exp_dims = tf.expand_dims(mus_exp_dims, -1)

        zs_minus_mus = zs_exp_dims - mus_exp_dims

        sigmas_ = tf.matmul(zs_minus_mus, zs_minus_mus, transpose_b=True)
        broadcast_gammas = tf.expand_dims(gammas, axis=-1)
        broadcast_gammas = tf.expand_dims(broadcast_gammas, axis=-1)
        sigmas_ = broadcast_gammas * sigmas_
        sigmas_ = tf.reduce_sum(sigmas_, axis=0)
        sigmas_ = sigmas_ / tf.expand_dims(sums_exp_dims, axis=-1)
        sigmas_ = add_noise(sigmas_)

        def assign_training_sigmas():
            with tf.control_dependencies([sigmas.assign(sigmas_)]):
                return tf.identity(sigmas)

        sigmas = tf.cond(is_training, assign_training_sigmas, lambda: sigmas)

    with tf.name_scope('loss'):
        loss_reconstruction = tf.reduce_mean(squared_euclidean,
                                             name='loss_reconstruction')
        inversed_sigmas = tf.expand_dims(tf.matrix_inverse(sigmas), axis=0)
        inversed_sigmas = tf.tile(inversed_sigmas,
                                  [tf.shape(zs_minus_mus)[0], 1, 1, 1])
        energy = tf.matmul(zs_minus_mus, inversed_sigmas, transpose_a=True)
        energy = tf.matmul(energy, zs_minus_mus)
        energy = tf.squeeze(phis_exp_dims * tf.exp(-0.5 * energy), axis=[2, 3])
        energy_divided_by = tf.expand_dims(tf.sqrt(
            2.0 * math.pi * tf.matrix_determinant(sigmas)),
                                           axis=0) + 1e-12
        energy = tf.reduce_sum(energy / energy_divided_by, axis=1) + 1e-12
        energy = -1.0 * tf.log(energy)
        energy_mean = tf.reduce_sum(energy) / n_count
        loss_sigmas_diag = 1.0 / tf.matrix_diag_part(sigmas)
        loss_sigmas_diag = tf.reduce_sum(loss_sigmas_diag)
        loss = loss_reconstruction + lambda_1 * energy_mean + lambda_2 * loss_sigmas_diag

    return energy, z, loss, loss_reconstruction, energy_mean, loss_sigmas_diag
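
A hypothetical call sketch for dagmm, reusing the region_tensors layout from baseline_ae_model above; shapes, names and hyperparameters are illustrative, and AutoEncoder, Estimator, base_dense_layer and reconstruction_distances must come from the surrounding module.

import tensorflow as tf

region_tensors = {
    'region_a': {
        'tensors': tf.placeholder(tf.float32, [None, 32, 32, 1]),  # shape assumed
        'filters': [16, 32],                                       # assumed
        'reuse': False,
        'scope': 'a',
    },
}
is_training = tf.placeholder(tf.bool, name='is_training')
energy, z, loss, loss_rec, energy_mean, sigma_penalty = dagmm(
    region_tensors, is_training,
    encoded_dims=2, mixtures=3, lambda_1=0.1, lambda_2=0.005)
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)  # optimizer choice assumed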