def __init__(self, args):
    """Set up damage detection: seed RNGs, load the test split of the
    white-noise dataset, the sensor spots, the AutoEncoder and the latent
    features saved during training."""
    self.args = args
    # Seed torch and numpy so detection runs are reproducible.
    torch.manual_seed(self.args.seed)
    np.random.seed(self.args.seed)
    print('{} detection...'.format(args.dataset))
    # Reader for the measurement data; calling it returns (train, test).
    white_noise = dp.DatasetReader(white_noise=self.args.dataset,
                                   data_path=data_path,
                                   data_source=args.data_source,
                                   len_seg=self.args.len_seg)
    _, self.testset = white_noise(args.net_name)
    # Sensor spot labels — presumably one entry per sensor location;
    # confirm against the file written alongside the dataset.
    self.spots = np.load('{}/spots.npy'.format(info_path))
    self.AE = AutoEncoder(args)
    # Latent features written out during training for this run name.
    self.latent = np.load('{}/features/{}.npy'.format(
        save_path, self.file_name()))
def __init__(self, args):
    """Prepare reconstruction plotting: load the baseline 'W-1' white-noise
    dataset, the sensor spots, and an AutoEncoder (weights loaded later via
    load_model)."""
    self.args = args
    # The baseline (undamaged) white-noise case is hard-coded to 'W-1'.
    white_noise = DatasetReader(white_noise='W-1',
                                data_path=data_path,
                                data_source=args.data_source,
                                len_seg=self.args.len_seg)
    self.dataset, _ = white_noise(args.net_name)
    self.spots = np.load('{}/spots.npy'.format(info_path))
    self.AE = AutoEncoder(args)
    # Shared font settings for matplotlib titles/labels.
    self.font = {
        'family': 'Arial',
        'style': 'normal',
        'weight': 'bold',
        'size': 10,
        'color': 'k',
    }
def train_model(x_data, y_data):
    """Train a VAE or AutoEncoder (selected by ``args.name``) on the data.

    Parameters
    ----------
    x_data, y_data : training inputs/targets; reshuffled into mini-batches
        every epoch via ``Utils.SuffleData``.

    Returns
    -------
    The trained model.

    Raises
    ------
    ValueError
        If ``args.name`` is neither ``"VAE"`` nor ``"AutoEncoder"``.
    """
    if args.name == "VAE":
        model = vae(args, device).to(device)
    elif args.name == "AutoEncoder":
        model = AutoEncoder(args, device).to(device)
    else:
        # Fail fast with a clear message instead of crashing later with
        # `AttributeError: 'NoneType' object has no attribute 'learn'`.
        raise ValueError('Unknown model name: {}'.format(args.name))
    for epoch in range(args.nb_epochs):
        # Fresh shuffle each epoch; y_train is unused by the unsupervised
        # models but kept for interface symmetry with SuffleData.
        x_train, y_train = Utils.SuffleData(x_data, y_data, args.batch_size)
        loss = model.learn(x_train)
        if epoch % args.log_interval == 0:
            print('Epoch {:4d}/{} loss: {:.6f} '.format(
                epoch, args.nb_epochs, loss))
    return model
def __init__(self, args):
    """Set up training: seed RNGs, echo all arguments, build the data
    loader, the AutoEncoder, the MSE criterion and the visdom visualizer."""
    self.args = args
    torch.manual_seed(self.args.seed)
    np.random.seed(self.args.seed)
    # Echo every parsed argument for reproducibility of the run log.
    print('> Training arguments:')
    for arg in vars(args):
        print('>>> {}: {}'.format(arg, getattr(args, arg)))
    white_noise = dp.DatasetReader(white_noise=self.args.dataset,
                                   data_path=data_path,
                                   data_source=args.data_source,
                                   len_seg=self.args.len_seg)
    dataset, _ = white_noise(args.net_name)
    # shuffle=False — presumably to keep segments ordered by sensor spot
    # for the feature bookkeeping during training; confirm.
    self.data_loader = DataLoader(dataset=dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False)
    self.spots = np.load('{}/spots.npy'.format(info_path))
    self.AE = AutoEncoder(args).to(device)  # AutoEncoder
    self.AE.apply(self.weights_init)
    self.criterion = nn.MSELoss()
    # One visdom environment per run name, mirrored to a log file.
    self.vis = visdom.Visdom(
        env='{}'.format(self.file_name()),
        log_to_filename='{}/visualization/{}.log'.format(
            save_path, self.file_name()))
    plt.figure(figsize=(15, 15))
def baseline_ae_model(region_tensors, is_training, encoded_dims):
    """Build one autoencoder per region and aggregate reconstruction error.

    :param region_tensors: dict mapping region name -> dict with keys
        'tensors', 'filters', 'reuse', 'scope' used to build each AutoEncoder
    :param is_training: placeholder indicating training vs. inference phase
    :param encoded_dims: latent dimensionality of each autoencoder
    :return: (relative reconstruction distance
              sqrt(sum ||x - x_hat||^2) / sqrt(sum ||x||^2),
              flattened encoding, mean squared euclidean error)
    """
    squared_euclidean = squared_x = None
    for name in region_tensors:
        tensors = region_tensors[name]['tensors']
        filters = region_tensors[name]['filters']
        reuse = region_tensors[name]['reuse']
        scope = region_tensors[name]['scope']
        ae = AutoEncoder(tensors,
                         encoded_dims=encoded_dims,
                         filters=filters,
                         is_training=is_training,
                         reuse=reuse,
                         name='compressor_{}'.format(scope))
        # cosine_similarity is computed but unused in this baseline.
        sq_x, sq_euclidean, cosine_similarity = reconstruction_distances(
            ae.input_tensor, ae.reconstruction)
        with tf.name_scope('latent_variables'):
            # Accumulate squared norms over all regions.
            if squared_euclidean is None:
                squared_euclidean = sq_euclidean
                squared_x = sq_x
            else:
                squared_euclidean = squared_euclidean + sq_euclidean
                squared_x = squared_x + sq_x
    # NOTE(review): `ae` is the autoencoder of the LAST region iterated, so
    # only that region's encoding is returned — confirm this is intended
    # when region_tensors has more than one entry.
    return tf.sqrt(squared_euclidean) / tf.sqrt(squared_x), \
        ae.flatten_encoded, tf.reduce_mean(squared_euclidean)
class Reconstruction:
    """Plot original vs. reconstructed signals (and latent codes) for a
    trained AutoEncoder on the baseline 'W-1' white-noise dataset."""

    def __init__(self, args):
        """Load the baseline dataset, sensor spots and an AutoEncoder
        (weights loaded later via load_model)."""
        self.args = args
        # Baseline (undamaged) case is hard-coded to 'W-1'.
        white_noise = DatasetReader(white_noise='W-1',
                                    data_path=data_path,
                                    data_source=args.data_source,
                                    len_seg=self.args.len_seg)
        self.dataset, _ = white_noise(args.net_name)
        self.spots = np.load('{}/spots.npy'.format(info_path))
        self.AE = AutoEncoder(args)
        # Shared font settings for matplotlib titles/labels.
        self.font = {
            'family': 'Arial',
            'style': 'normal',
            'weight': 'bold',
            'size': 10,
            'color': 'k',
        }

    def file_name(self):
        """Return the run identifier used for model/feature file names.

        MLP runs omit num_hidden_map; all other nets include it.
        """
        if self.args.net_name == 'MLP':
            return '{}_{}_{}_{}_{}_{}'.format(self.args.model_name,
                                              self.args.net_name,
                                              self.args.len_seg,
                                              self.args.optimizer,
                                              self.args.learning_rate,
                                              self.args.num_epoch)
        else:
            return '{}_{}_{}_{}_{}_{}_{}'.format(
                self.args.model_name, self.args.net_name,
                self.args.len_seg, self.args.optimizer,
                self.args.learning_rate, self.args.num_epoch,
                self.args.num_hidden_map)

    def load_model(self):
        """Load the trained AutoEncoder weights for this run."""
        path = '{}/models/{}/{}.model'.format(save_path,
                                              self.args.model_name,
                                              self.file_name())
        self.AE.load_state_dict(
            torch.load(path, map_location=torch.device(device)))  # Load AutoEncoder

    def show_reconstruct(self):
        """Plot original vs. reconstructed signal for one segment
        (args.seg_idx) of every sensor spot; L1 spots in the left column,
        L2 spots in the right."""
        self.load_model()
        fig, axs = plt.subplots(nrows=int(len(self.spots) / 2), ncols=2,
                                figsize=(10, 10), sharey=True)
        seg_idx = self.args.seg_idx
        with torch.no_grad():
            # Segments per spot; dataset rows are grouped by spot.
            num_seg = int(self.dataset.shape[0] / len(self.spots))
            # First half of spots = level L1, second half = level L2.
            spots_l1, spots_l2 = np.hsplit(self.spots, 2)
            for i, (spot_l1, spot_l2) in enumerate(zip(spots_l1, spots_l2)):
                # L1 sensors
                x = self.dataset[i * num_seg + seg_idx]
                x = x.to(device)
                axs[i][0].set_title('{}-{}'.format(spot_l1, seg_idx),
                                    fontdict=self.font)
                axs[i][0].plot(x.view(-1).detach().cpu().numpy(), c='b',
                               lw=1, label='Original')
                if self.args.net_name == 'Conv2D':
                    # Conv2D expects (batch, channel, 1, length).
                    x = x.unsqueeze(0).unsqueeze(2)
                x_hat, _, _ = self.AE(x)
                axs[i][0].plot(x_hat.view(-1).detach().cpu().numpy(),
                               ls='--', lw=1, c='r', label='Reconstructed')
                # Vertical separators at the NS/EW/V component boundaries.
                axs[i][0].axvline(x=127, ls='--', c='k')
                axs[i][0].axvline(x=255, ls='--', c='k')
                axs[i][0].set_xticks(
                    np.linspace(self.args.dim_input / 6,
                                self.args.dim_input
                                - self.args.dim_input / 6,
                                3))
                axs[i][0].set_xticklabels(['NS', 'EW', 'V'])
                axs[i][0].set_ylabel('Amplitude', fontdict=self.font)
                axs[i][0].legend(loc='upper center', prop={'size': 8})
                # L2 sensors — offset (i + 5) into the second half;
                # assumes 5 spots per level, TODO confirm.
                x = self.dataset[(i + 5) * num_seg + seg_idx]
                x = x.to(device)
                axs[i][1].plot(x.view(-1).detach().cpu().numpy(), c='b',
                               lw=1, label='Original')
                axs[i][1].set_title('{}-{}'.format(spot_l2, seg_idx),
                                    fontdict=self.font)
                if self.args.net_name == 'Conv2D':
                    x = x.unsqueeze(0).unsqueeze(2)
                x_hat, _, _ = self.AE(x)
                axs[i][1].plot(x_hat.view(-1).detach().cpu().numpy(),
                               ls='--', lw=1, c='r', label='Reconstructed')
                axs[i][1].axvline(x=127, ls='--', c='k')
                axs[i][1].axvline(x=255, ls='--', c='k')
                axs[i][1].set_xticks(
                    np.linspace(self.args.dim_input / 6,
                                self.args.dim_input
                                - self.args.dim_input / 6,
                                3))
                axs[i][1].set_xticklabels(['NS', 'EW', 'V'])
                axs[i][1].legend(loc='upper center', prop={'size': 8})
        plt.tight_layout()

    def show_latent_reconstruction(self):
        """Plot the latent code z vs. its reconstruction z_hat for one
        segment (args.seg_idx) of every sensor spot."""
        self.load_model()
        fig, axs = plt.subplots(nrows=int(len(self.spots) / 2), ncols=2,
                                figsize=(10, 10))
        seg_idx = self.args.seg_idx
        with torch.no_grad():
            num_seg = int(self.dataset.shape[0] / len(self.spots))
            spots_l1, spots_l2 = np.hsplit(self.spots, 2)
            for i, (spot_l1, spot_l2) in enumerate(zip(spots_l1, spots_l2)):
                # L1 sensors
                x = self.dataset[i * num_seg + seg_idx]
                x = x.to(device)
                axs[i][0].set_title('{}-{}'.format(spot_l1, seg_idx))
                if self.args.net_name == 'Conv2D':
                    x = x.unsqueeze(0).unsqueeze(2)
                # AE returns (x_hat, z, z_hat); only the latents are used.
                _, z, z_hat = self.AE(x)
                axs[i][0].plot(z.view(-1).detach().cpu().numpy(), ls='--',
                               lw=1, c='b', label='Original')
                axs[i][0].plot(z_hat.view(-1).detach().cpu().numpy(),
                               ls='--', lw=1, c='r', label='Reconstructed')
                axs[i][0].set_xticks(
                    np.linspace(0, z.view(-1).size(0) - 1, 3))
                axs[i][0].set_xticklabels([1, '...', z.view(-1).size(0)])
                axs[i][0].legend(loc='upper center', prop={'size': 8})
                # L2 sensors (same (i + 5) offset as show_reconstruct).
                x = self.dataset[(i + 5) * num_seg + seg_idx]
                x = x.to(device)
                axs[i][1].set_title('{}-{}'.format(spot_l2, seg_idx))
                if self.args.net_name == 'Conv2D':
                    x = x.unsqueeze(0).unsqueeze(2)
                _, z, z_hat = self.AE(x)
                axs[i][1].plot(z.view(-1).detach().cpu().numpy(), ls='--',
                               lw=1, c='b', label='Original')
                axs[i][1].plot(z_hat.view(-1).detach().cpu().numpy(),
                               ls='--', lw=1, c='r', label='Reconstructed')
                axs[i][1].set_xticks(
                    np.linspace(0, z.view(-1).size(0) - 1, 3))
                axs[i][1].set_xticklabels([1, '...', z.view(-1).size(0)])
                axs[i][1].legend(loc='upper center', prop={'size': 8})
        plt.tight_layout()
from tensorflow.python.keras.datasets import mnist
from models.AutoEncoder import AutoEncoder
import tensorflow as tf
import numpy as np


def _prepare(images):
    """Scale pixels to [0, 1] and add a trailing single-channel axis."""
    scaled = images.astype('float32') / 255.
    return np.reshape(scaled, (len(scaled), 28, 28, 1))


# MNIST digits; the labels are irrelevant for autoencoder training.
(x_train, _), (x_test, _) = mnist.load_data()
x_train = _prepare(x_train)
x_test = _prepare(x_test)

autoencoder = AutoEncoder()
autoencoder.build()
# autoencoder.train(x_train, x_test)
class BaseExperiment:
    """Train an AutoEncoder (or VAE) on white-noise sensor data, logging
    loss and reconstructions to visdom and saving the best model,
    extracted latent features and learning history to disk."""

    def __init__(self, args):
        """Seed RNGs, echo the arguments, and build the data loader,
        network, loss and visdom visualizer."""
        self.args = args
        torch.manual_seed(self.args.seed)
        np.random.seed(self.args.seed)
        # Echo every parsed argument for reproducibility of the run log.
        print('> Training arguments:')
        for arg in vars(args):
            print('>>> {}: {}'.format(arg, getattr(args, arg)))
        white_noise = dp.DatasetReader(white_noise=self.args.dataset,
                                       data_path=data_path,
                                       data_source=args.data_source,
                                       len_seg=self.args.len_seg)
        dataset, _ = white_noise(args.net_name)
        # shuffle=False — presumably to keep segments ordered by sensor
        # spot so the feature tensor filled in train() stays aligned
        # with self.spots; confirm.
        self.data_loader = DataLoader(dataset=dataset,
                                      batch_size=args.batch_size,
                                      shuffle=False)
        self.spots = np.load('{}/spots.npy'.format(info_path))
        self.AE = AutoEncoder(args).to(device)  # AutoEncoder
        self.AE.apply(self.weights_init)
        self.criterion = nn.MSELoss()
        # One visdom environment per run name, mirrored to a log file.
        self.vis = visdom.Visdom(
            env='{}'.format(self.file_name()),
            log_to_filename='{}/visualization/{}.log'.format(
                save_path, self.file_name()))
        plt.figure(figsize=(15, 15))

    def select_optimizer(self, model):
        """Build the optimizer named by args.optimizer for `model`.

        NOTE(review): an unrecognized optimizer name falls through every
        branch and raises UnboundLocalError on `return optimizer` —
        consider validating args.optimizer upstream.
        """
        if self.args.optimizer == 'Adam':
            optimizer = optim.Adam(
                filter(lambda p: p.requires_grad, model.parameters()),
                lr=self.args.learning_rate,
                betas=(0.5, 0.999),
            )
        elif self.args.optimizer == 'AdaBelief':
            optimizer = AdaBelief(model.parameters(),
                                  lr=self.args.learning_rate,
                                  betas=(0.5, 0.999))
        elif self.args.optimizer == 'RMS':
            optimizer = optim.RMSprop(filter(lambda p: p.requires_grad,
                                             model.parameters()),
                                      lr=self.args.learning_rate)
        elif self.args.optimizer == 'SGD':
            optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                         model.parameters()),
                                  lr=self.args.learning_rate,
                                  momentum=0.9)
        elif self.args.optimizer == 'Adagrad':
            optimizer = optim.Adagrad(filter(lambda p: p.requires_grad,
                                             model.parameters()),
                                      lr=self.args.learning_rate)
        elif self.args.optimizer == 'Adadelta':
            optimizer = optim.Adadelta(filter(lambda p: p.requires_grad,
                                              model.parameters()),
                                       lr=self.args.learning_rate)
        return optimizer

    def weights_init(self, m):
        """Per-module weight init applied via self.AE.apply().

        Linear layers use the initializer named by args.initializer;
        Conv2d layers always use N(0, 0.02).
        """
        initializers = {
            'xavier_uniform_': nn.init.xavier_uniform_,
            # NOTE(review): this maps to nn.init.xavier_normal (no trailing
            # underscore) — a deprecated alias of xavier_normal_; likely a
            # typo, removed in newer torch versions. Confirm and fix.
            'xavier_normal_': nn.init.xavier_normal,
            'orthogonal_': nn.init.orthogonal_,
            'kaiming_normal_': nn.init.kaiming_normal_
        }
        initializer = initializers[self.args.initializer]
        if isinstance(m, nn.Linear):
            initializer(m.weight)
            m.bias.data.fill_(0)
        elif isinstance(m, nn.Conv2d):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

    def file_name(self):
        """Return the run identifier used for model/feature/log file names.

        MLP runs omit num_hidden_map; all other nets include it.
        """
        if self.args.net_name == 'MLP':
            return '{}_{}_{}_{}_{}_{}'.format(self.args.model_name,
                                              self.args.net_name,
                                              self.args.len_seg,
                                              self.args.optimizer,
                                              self.args.learning_rate,
                                              self.args.num_epoch)
        else:
            return '{}_{}_{}_{}_{}_{}_{}'.format(
                self.args.model_name, self.args.net_name,
                self.args.len_seg, self.args.optimizer,
                self.args.learning_rate, self.args.num_epoch,
                self.args.num_hidden_map)

    def train(self):
        """Run the training loop; checkpoint model + latent features on the
        best (lowest last-batch) loss, and write the learning history JSON."""
        optimizer = self.select_optimizer(self.AE)
        best_loss = 100.  # initial sentinel; assumes real losses are below
        best_epoch = 1
        lh = {}  # learning history, serialized to JSON at the end
        losses, mses_x, mses_z = [], [], []
        for epoch in range(self.args.num_epoch):
            t0 = time.time()
            # Pre-allocate the per-epoch latent feature tensor; shape
            # depends on the network/model variant.
            if self.args.net_name == 'MLP':
                if self.args.model_name == 'VAE':
                    f = torch.zeros(len(self.data_loader.dataset),
                                    int(self.args.dim_feature / 2))
                else:
                    f = torch.zeros(len(self.data_loader.dataset),
                                    int(self.args.dim_feature))
            else:
                f = torch.zeros(len(self.data_loader.dataset),
                                self.args.num_hidden_map, 1, 8)
            idx = 0
            for _, sample_batched in enumerate(self.data_loader):
                batch_size = sample_batched.size(0)
                x = sample_batched.to(device)
                if self.args.net_name == 'Conv2D':
                    # Conv2D expects an extra height-1 spatial dim.
                    x = x.unsqueeze(2)
                if self.args.model_name == 'VAE':
                    x_hat, z, z_kld = self.AE(x)
                    loss = self.criterion(x_hat, x)
                    # Negative ELBO = reconstruction + KL (weight 1.0).
                    elbo = -loss - 1.0 * z_kld
                    loss = -elbo
                else:
                    x_hat, z, z_hat = self.AE(x)
                    mse_x = self.criterion(x_hat, x)
                    mse_z = self.criterion(z_hat, z)
                    # beta trades off signal vs. latent reconstruction.
                    loss = self.args.beta * mse_x \
                        + (1 - self.args.beta) * mse_z
                    loss = loss  # NOTE(review): no-op; safe to delete
                # Record this batch's latents in dataset order.
                f[idx:idx + batch_size] = z
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                idx += batch_size
            t1 = time.time()
            if self.args.model_name == 'VAE':
                print('\033[1;31mEpoch: {}\033[0m\t'
                      '\033[1;32mReconstruction loss: {:5f}\033[0m\t'
                      '\033[1;33mKL Divergence: {:5f}\033[0m\t'
                      '\033[1;35mTime cost: {:2f}s\033[0m'.format(
                          epoch + 1, loss.item(), z_kld, t1 - t0))
            else:
                print('\033[1;31mEpoch: {}\033[0m\t'
                      '\033[1;32mLoss: {:5f}\033[0m\t'
                      '\033[1;33mMSE: {:5f}\033[0m\t'
                      '\033[1;34mMSE_latent: {:5f}\033[0m\t'
                      '\033[1;35mTime cost: {:2f}s\033[0m'.format(
                          epoch + 1, loss.item(), mse_x.item(),
                          mse_z.item(), t1 - t0))
            # Checkpoint on the last batch's loss — NOTE(review): this is
            # not an epoch-average loss; confirm that is intended.
            if loss.item() < best_loss:
                best_loss = loss.item()
                best_epoch = epoch + 1
                f = f.detach().numpy()
                path = '{}/models/{}/{}.model'.format(
                    save_path, self.args.model_name, self.file_name())
                torch.save(self.AE.state_dict(), path)
                np.save(
                    '{}/features/{}.npy'.format(save_path,
                                                self.file_name()), f)
            losses.append(loss.item())
            mses_x.append(mse_x.item())
            mses_z.append(mse_z.item())
            self.show_loss(loss, epoch)
            self.show_reconstruction(epoch)
        plt.close()
        lh['Loss'] = losses
        lh['MSE'] = mses_x
        lh['MSE latent'] = mses_z
        lh['Min loss'] = best_loss
        lh['Best epoch'] = best_epoch
        lh = json.dumps(lh, indent=2)
        with open(
                '{}/learning history/{}.json'.format(save_path,
                                                     self.file_name()),
                'w') as f:
            f.write(lh)

    def show_loss(self, loss, epoch):
        """Append this epoch's loss to the visdom 'Train loss' curve."""
        self.vis.line(Y=np.array([loss.item()]),
                      X=np.array([epoch + 1]),
                      win='Train loss',
                      opts=dict(title='Train loss'),
                      update='append')

    def show_reconstruction(self, epoch, seg_idx=25):
        """Plot original vs. reconstructed signal for segment `seg_idx` of
        every sensor spot and push the figure to visdom."""
        plt.clf()
        # Segments per spot; dataset rows are grouped by spot.
        num_seg = int(self.data_loader.dataset.shape[0] / len(self.spots))
        spots_l1, spots_l2 = np.hsplit(self.spots, 2)
        for i, (spot_l1, spot_l2) in enumerate(zip(spots_l1, spots_l2)):
            # L1 sensors
            plt.subplot(int(len(self.spots) / 2), 2, 2 * i + 1)
            x = self.data_loader.dataset[i * num_seg + seg_idx]
            x = x.to(device)
            plt.plot(x.view(-1).detach().cpu().numpy(), label='original')
            plt.title('A-{}-{}'.format(spot_l1, seg_idx))
            if self.args.net_name == 'Conv2D':
                x = x.unsqueeze(0).unsqueeze(2)
            x_hat, _, _ = self.AE(x)
            plt.plot(x_hat.view(-1).detach().cpu().numpy(),
                     label='reconstruct')
            # Separators at the NS/EW/V component boundaries.
            plt.axvline(x=127, ls='--', c='k')
            plt.axvline(x=255, ls='--', c='k')
            plt.legend(loc='upper center')
            # L2 sensors — (i + 5) offsets into the second half of spots;
            # assumes 5 spots per level, TODO confirm.
            plt.subplot(int(len(self.spots) / 2), 2, 2 * (i + 1))
            x = self.data_loader.dataset[(i + 5) * num_seg + seg_idx]
            x = x.to(device)
            plt.plot(x.view(-1).detach().cpu().numpy(), label='original')
            plt.title('A-{}-{}'.format(spot_l2, seg_idx))
            if self.args.net_name == 'Conv2D':
                x = x.unsqueeze(0).unsqueeze(2)
            x_hat, _, _ = self.AE(x)
            plt.plot(x_hat.view(-1).detach().cpu().numpy(),
                     label='reconstruct')
            plt.axvline(x=127, ls='--', c='k')
            plt.axvline(x=255, ls='--', c='k')
            plt.legend(loc='upper center')
        plt.subplots_adjust(hspace=0.5)
        self.vis.matplot(plt, win='Reconstruction',
                         opts=dict(title='Epoch: {}'.format(epoch + 1)))
class DamageDetection:
    """Run a trained AutoEncoder over a (possibly damaged) dataset and
    compute a per-spot damage index from reconstruction + latent losses."""

    def __init__(self, args):
        """Seed RNGs, load the test split, sensor spots, AutoEncoder and
        the training-time latent features."""
        self.args = args
        torch.manual_seed(self.args.seed)
        np.random.seed(self.args.seed)
        print('{} detection...'.format(args.dataset))
        white_noise = dp.DatasetReader(white_noise=self.args.dataset,
                                       data_path=data_path,
                                       data_source=args.data_source,
                                       len_seg=self.args.len_seg)
        _, self.testset = white_noise(args.net_name)
        self.spots = np.load('{}/spots.npy'.format(info_path))
        self.AE = AutoEncoder(args)
        # Latent features written out during training for this run name.
        self.latent = np.load('{}/features/{}.npy'.format(
            save_path, self.file_name()))

    def __call__(self, *args, **kwargs):
        # Calling the object runs the detection pass.
        self.test()

    def file_name(self):
        """Return the run identifier used for model/feature file names.

        MLP runs omit num_hidden_map; all other nets include it.
        """
        if self.args.net_name == 'MLP':
            return '{}_{}_{}_{}_{}_{}'.format(self.args.model_name,
                                              self.args.net_name,
                                              self.args.len_seg,
                                              self.args.optimizer,
                                              self.args.learning_rate,
                                              self.args.num_epoch)
        else:
            return '{}_{}_{}_{}_{}_{}_{}'.format(
                self.args.model_name, self.args.net_name,
                self.args.len_seg, self.args.optimizer,
                self.args.learning_rate, self.args.num_epoch,
                self.args.num_hidden_map)

    def damage_index(self, err):
        """Map a non-negative loss to [0, 1): 0 for no error, approaching 1
        as err grows; args.alpha controls the sensitivity."""
        return 1 - np.exp(-self.args.alpha * err)

    def test(self):
        """Evaluate each sensor spot, print per-spot losses/indices, and
        save the damage-index JSON and test-time latent features."""
        path = '{}/models/{}/{}.model'.format(save_path,
                                              self.args.model_name,
                                              self.file_name())
        self.AE.load_state_dict(
            torch.load(path, map_location=torch.device(device)))  # Load AutoEncoder
        self.AE.eval()
        damage_indices = {}
        # testset presumably has shape (num_spots, segments_per_spot, ...);
        # TODO confirm against DatasetReader.
        size_0 = self.testset.size(0)
        size_1 = self.testset.size(1)
        # Pre-allocate the latent feature tensor for all spots/segments.
        if self.args.net_name == 'MLP':
            if self.args.model_name == 'VAE':
                feats = torch.zeros(size_0 * size_1,
                                    int(self.args.dim_feature / 2))
            else:
                feats = torch.zeros(size_0 * size_1,
                                    int(self.args.dim_feature))
        else:
            feats = torch.zeros(size_0 * size_1,
                                self.args.num_hidden_map * 8)
        idx = 0
        with torch.no_grad():
            for i, spot in enumerate(self.spots):
                damage_indices[spot] = {}
                x = self.testset[i]  # all segments of this spot
                x_size = x.size(0)
                # z_w1 = self.latent[idx: idx + x_size]
                if self.args.net_name == 'Conv2D':
                    x = x.unsqueeze(2)
                x_hat, z, z_hat = self.AE(x)
                if self.args.net_name == 'Conv2D':  # Flatten
                    z = z.reshape(x_size, -1)
                    z_hat = z_hat.reshape(x_size, -1)
                feats[idx:idx + x_size] = z  # Latent
                loss_x = ((x - x_hat)**2).mean()  # Reconstruction loss
                loss_z = ((z - z_hat)**2).mean()  # Latent loss
                # Same beta-weighting as used during training.
                loss = self.args.beta * loss_x.item() + (
                    1 - self.args.beta) * loss_z.item()  # Overall loss
                damage_index = self.damage_index(loss)
                damage_indices[spot]['Reconstruction loss'] = loss_x.item()
                damage_indices[spot]['Latent loss'] = loss_z.item()
                damage_indices[spot]['Damage index'] = damage_index
                print('\033[1;32m[{}]\033[0m\t'
                      '\033[1;31mReconstruction loss: {:5f}\033[0m\t'
                      '\033[1;33mLatent loss: {:5f}\033[0m\t'
                      '\033[1;34mLoss: {:5f}\033[0m\t'
                      '\033[1;35mDamage index: {:5f}\033[0m'.format(
                          spot, loss_x.item(), loss_z.item(), loss,
                          damage_index))
                idx += x_size
        damage_indices = json.dumps(damage_indices, indent=2)
        with open(
                '{}/damage index/{}_{}.json'.format(save_path,
                                                    self.args.dataset,
                                                    self.file_name()),
                'w') as f:
            f.write(damage_indices)
        np.save(
            '{}/features/test/{}_{}.npy'.format(save_path,
                                                self.args.dataset,
                                                self.file_name()), feats)
from Preprocessor import Preprocessor
from train.SDNetTrainer import SDNetTrainer
from datasets.STL10 import STL10
from models.AutoEncoder import AutoEncoder
from models.SpotNet import SNet

# Settings shared by the autoencoder and the SpotNet model.
target_shape = [96, 96, 3]
batch_size = 128
tag = 'default'

# Four-layer autoencoder feeding the SpotNet (SNet) model.
ae = AutoEncoder(num_layers=4, batch_size=batch_size,
                 target_shape=target_shape, tag=tag)
model = SNet(ae, batch_size=batch_size, target_shape=target_shape,
             disc_pad='SAME', tag=tag)

# STL-10 with color-augmenting preprocessing.
data = STL10()
preprocessor = Preprocessor(target_shape=target_shape, augment_color=True)

trainer = SDNetTrainer(model=model, dataset=data,
                       pre_processor=preprocessor, num_epochs=500,
                       init_lr=0.0003, lr_policy='linear', num_gpus=2)
trainer.train_model(None)
def dagmm(region_tensors, is_training, encoded_dims=2, mixtures=3,
          lambda_1=0.1, lambda_2=0.005, use_cosine_similarity=False,
          latent_dims=2):
    """Build a Deep Autoencoding Gaussian Mixture Model (DAGMM) graph.

    One autoencoder is built per region; each region contributes a reduced
    latent vector plus reconstruction-distance features to the joint latent
    z, on which a GMM energy is estimated.

    :param region_tensors: dict mapping region name -> dict with keys
        'tensors', 'filters', 'reuse', 'scope' (see AutoEncoder)
    :param is_training: a tensorflow placeholder to indicate whether it is
        in the training phase or not; GMM parameters (phis, mus, sigmas)
        are updated only when it is true
    :param encoded_dims: latent dimensionality of each region autoencoder
    :param mixtures: number of Gaussian mixture components
    :param lambda_1: weight of the mean sample energy in the loss
    :param lambda_2: weight of the sigma diagonal penalty in the loss
    :param use_cosine_similarity: if true, append cosine similarity next to
        the relative euclidean distance in the latent features
    :param latent_dims: reduce the dimension of encoded vector to a smaller
        one before concatenating into z
    :return: (per-sample energy, latent z, total loss, reconstruction loss,
        mean energy, sigma diagonal penalty)
    """
    squared_x = squared_euclidean = z = None
    for name in region_tensors:
        tensors = region_tensors[name]['tensors']
        filters = region_tensors[name]['filters']
        reuse = region_tensors[name]['reuse']
        scope = region_tensors[name]['scope']
        ae = AutoEncoder(tensors,
                         encoded_dims=encoded_dims,
                         filters=filters,
                         is_training=is_training,
                         reuse=reuse,
                         name='compressor_{}'.format(scope))
        # Project the flattened encoding down to latent_dims (no BN, no
        # activation — a plain linear layer).
        reduced_latent = base_dense_layer(ae.flatten_encoded,
                                          latent_dims,
                                          'reducer_{}'.format(name),
                                          is_training=is_training,
                                          bn=False,
                                          activation_fn=None)
        sq_x, sq_euclidean, cosine_similarity = reconstruction_distances(
            ae.input_tensor, ae.reconstruction)
        with tf.name_scope('latent_variables'):
            # Distance features appended to the latent code, as in the
            # DAGMM paper: relative euclidean (and optionally cosine).
            if use_cosine_similarity:
                relative_euclidean = tf.sqrt(sq_euclidean) / tf.sqrt(sq_x)
                relative_euclidean = tf.reshape(relative_euclidean,
                                                [-1, 1])
                cosine_similarity = tf.reshape(cosine_similarity, [-1, 1])
                distances = tf.concat(
                    [relative_euclidean, cosine_similarity], axis=1)
            else:
                distances = tf.sqrt(sq_euclidean) / tf.sqrt(sq_x)
                distances = tf.reshape(distances, [-1, 1])
            # Accumulate norms and concatenate per-region features into z.
            if squared_x is None:
                squared_x = sq_x
                squared_euclidean = sq_euclidean
                z = tf.concat([reduced_latent, distances], axis=1)
            else:
                squared_x = squared_x + sq_x
                squared_euclidean = squared_euclidean + sq_euclidean
                z = tf.concat([z, reduced_latent, distances], axis=1)
    with tf.name_scope('n_count'):
        # Batch size as a float, for averaging.
        n_count = tf.shape(z)[0]
        n_count = tf.cast(n_count, tf.float32)
    # Estimation network producing soft mixture memberships (gammas).
    estimator = Estimator(mixtures, z, is_training=is_training)
    gammas = estimator.output_tensor
    with tf.variable_scope('gmm_parameters'):
        # Non-trainable variables holding the GMM parameters; they are
        # assigned from batch statistics during training and frozen at
        # inference.
        phis = tf.get_variable('phis',
                               shape=[mixtures],
                               initializer=tf.ones_initializer(),
                               dtype=tf.float32,
                               trainable=False)
        mus = tf.get_variable('mus',
                              shape=[mixtures, z.get_shape()[1]],
                              initializer=tf.ones_initializer(),
                              dtype=tf.float32,
                              trainable=False)
        # Covariances start as 0.5 * identity per mixture component.
        init_sigmas = 0.5 * np.expand_dims(np.identity(z.get_shape()[1]),
                                           axis=0)
        init_sigmas = np.tile(init_sigmas, [mixtures, 1, 1])
        init_sigmas = tf.constant_initializer(init_sigmas)
        sigmas = tf.get_variable(
            'sigmas',
            shape=[mixtures, z.get_shape()[1], z.get_shape()[1]],
            initializer=init_sigmas,
            dtype=tf.float32,
            trainable=False)
        # Batch estimates: phi_k = mean gamma, mu_k = gamma-weighted mean.
        sums = tf.reduce_sum(gammas, axis=0)
        sums_exp_dims = tf.expand_dims(sums, axis=-1)
        phis_ = sums / n_count
        mus_ = tf.matmul(gammas, z, transpose_a=True) / sums_exp_dims

        def assign_training_phis_mus():
            # Assign batch estimates, then read the updated values.
            with tf.control_dependencies(
                    [phis.assign(phis_), mus.assign(mus_)]):
                return [tf.identity(phis), tf.identity(mus)]

        # Update phis/mus only in the training phase.
        phis, mus = tf.cond(is_training, assign_training_phis_mus,
                            lambda: [phis, mus])
        # Broadcast shapes for the per-sample, per-mixture computations:
        # z -> [N, 1, D, 1], mus -> [1, K, D, 1], phis -> [1, K, 1, 1].
        phis_exp_dims = tf.expand_dims(phis, axis=0)
        phis_exp_dims = tf.expand_dims(phis_exp_dims, axis=-1)
        phis_exp_dims = tf.expand_dims(phis_exp_dims, axis=-1)
        zs_exp_dims = tf.expand_dims(z, 1)
        zs_exp_dims = tf.expand_dims(zs_exp_dims, -1)
        mus_exp_dims = tf.expand_dims(mus, 0)
        mus_exp_dims = tf.expand_dims(mus_exp_dims, -1)
        zs_minus_mus = zs_exp_dims - mus_exp_dims
        # Gamma-weighted outer products give the covariance estimate.
        sigmas_ = tf.matmul(zs_minus_mus, zs_minus_mus, transpose_b=True)
        broadcast_gammas = tf.expand_dims(gammas, axis=-1)
        broadcast_gammas = tf.expand_dims(broadcast_gammas, axis=-1)
        sigmas_ = broadcast_gammas * sigmas_
        sigmas_ = tf.reduce_sum(sigmas_, axis=0)
        sigmas_ = sigmas_ / tf.expand_dims(sums_exp_dims, axis=-1)
        # add_noise presumably regularizes the covariance (e.g. keeps it
        # invertible) — confirm its implementation.
        sigmas_ = add_noise(sigmas_)

        def assign_training_sigmas():
            with tf.control_dependencies([sigmas.assign(sigmas_)]):
                return tf.identity(sigmas)

        # Update sigmas only in the training phase.
        sigmas = tf.cond(is_training, assign_training_sigmas,
                         lambda: sigmas)
    with tf.name_scope('loss'):
        loss_reconstruction = tf.reduce_mean(squared_euclidean,
                                             name='loss_reconstruction')
        # Per-sample Gaussian-mixture energy:
        # E(z) = -log sum_k phi_k N(z; mu_k, sigma_k).
        inversed_sigmas = tf.expand_dims(tf.matrix_inverse(sigmas), axis=0)
        inversed_sigmas = tf.tile(inversed_sigmas,
                                  [tf.shape(zs_minus_mus)[0], 1, 1, 1])
        energy = tf.matmul(zs_minus_mus, inversed_sigmas, transpose_a=True)
        energy = tf.matmul(energy, zs_minus_mus)
        energy = tf.squeeze(phis_exp_dims * tf.exp(-0.5 * energy),
                            axis=[2, 3])
        # 1e-12 terms guard against division by zero / log(0).
        energy_divided_by = tf.expand_dims(tf.sqrt(
            2.0 * math.pi * tf.matrix_determinant(sigmas)),
                                           axis=0) + 1e-12
        energy = tf.reduce_sum(energy / energy_divided_by, axis=1) + 1e-12
        energy = -1.0 * tf.log(energy)
        energy_mean = tf.reduce_sum(energy) / n_count
        # Penalize small covariance diagonals to deter degenerate
        # (singular) solutions.
        loss_sigmas_diag = 1.0 / tf.matrix_diag_part(sigmas)
        loss_sigmas_diag = tf.reduce_sum(loss_sigmas_diag)
        loss = loss_reconstruction + lambda_1 * energy_mean \
            + lambda_2 * loss_sigmas_diag
    return energy, z, loss, loss_reconstruction, energy_mean, \
        loss_sigmas_diag