Example #1
class TEncoder:
    def __init__(self,
                 input_dim,
                 layer_sizes,
                 activations,
                 alpha=0.01,
                 learning_rate=0.001,
                 batch_size=128,
                 n_epochs=40,
                 early_stopping=True,
                 patience=5,
                 v1_compat_mode=False,
                 random_state=42):
        self.input_dim = input_dim
        self.layer_sizes = layer_sizes
        self.activations = activations
        self.alpha = alpha
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.early_stopping = early_stopping
        self.patience = patience
        self.random_state = random_state
        if v1_compat_mode:
            self.session = tf.compat.v1.Session()
        else:
            self.session = tf.Session()
        self.best_epoch_ = None
        self.last_fit_duration_ = None
        self.centroids = None
        self.altered_centroids = None
        self.v1_compat_mode = v1_compat_mode

    def compile(self):
        if self.v1_compat_mode:
            tf.compat.v1.disable_eager_execution()
        self._init_architecture()
        anchor_output = self.forward_pass(self.placeholders['anchor'])
        pos_output = self.forward_pass(self.placeholders['pos'])
        neg_output = self.forward_pass(self.placeholders['neg'])

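        # Triplet margin loss: hinge(d(anchor, pos) - d(anchor, neg) + alpha),
        # summed over the batch, where d is taken as the L2 norm of the
        # element-wise squared difference between embeddings.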
        ap_norm = tf.norm(tf.square(anchor_output - pos_output),
                          keepdims=True,
                          axis=1)
        an_norm = tf.norm(tf.square(anchor_output - neg_output),
                          keepdims=True,
                          axis=1)
        loss = tf.nn.relu(ap_norm - an_norm + self.alpha)
        loss = tf.reduce_sum(loss)

        if self.v1_compat_mode:
            adam = tf.compat.v1.train.AdamOptimizer
        else:
            adam = tf.train.AdamOptimizer

        train_step = adam(self.learning_rate).minimize(loss)

        self.loss = loss
        self.train_step = train_step

    def forward_pass(self, input_pl):
        output = input_pl
        for i in range(len(self.weights)):
            output = tf.matmul(output, self.weights[i]) + self.biases[i]
            if self.activations[i] == 'relu':
                output = tf.nn.relu(output)
            elif self.activations[i] == '' or self.activations[i] is None:
                pass
            else:
                raise NotImplementedError(
                    "This activation ({}) is not yet implemented.".format(
                        self.activations[i]))
        return output

    def _init_architecture(self):
        if self.v1_compat_mode:
            anchor_pl = tf.compat.v1.placeholder(tf.float32,
                                                 shape=(None, self.input_dim))
            pos_pl = tf.compat.v1.placeholder(tf.float32,
                                              shape=(None, self.input_dim))
            neg_pl = tf.compat.v1.placeholder(tf.float32,
                                              shape=(None, self.input_dim))
        else:
            anchor_pl = tf.placeholder(tf.float32,
                                       shape=(None, self.input_dim))
            pos_pl = tf.placeholder(tf.float32, shape=(None, self.input_dim))
            neg_pl = tf.placeholder(tf.float32, shape=(None, self.input_dim))

        self.placeholders = {'anchor': anchor_pl, 'pos': pos_pl, 'neg': neg_pl}

        weights = []
        biases = []
        i_dim = self.input_dim
        for layer_size in self.layer_sizes:
            w = weight_variable([i_dim, layer_size],
                                v1_compat_mode=self.v1_compat_mode)
            b = bias_variable([layer_size])
            i_dim = layer_size
            weights.append(w)
            biases.append(b)
        self.weights = weights
        self.biases = biases

        self.saver = Saver(self.weights + self.biases)

    def get_fd(self, X_a, X_p, X_n):
        return {
            self.placeholders['anchor']: X_a,
            self.placeholders['pos']: X_p,
            self.placeholders['neg']: X_n
        }

    def eval_var(self, var, X_a, X_p, X_n):
        return var.eval(feed_dict=self.get_fd(X_a, X_p, X_n),
                        session=self.session)

    def fit_idxs(self,
                 triplet_idxs,
                 fetch_method,
                 lods,
                 log_time=False,
                 verbose=False):
        t0 = time.time()

        triplet_idxs = np.array(triplet_idxs)

        if self.early_stopping:
            triplet_idxs, triplet_idxs_val = train_test_split(triplet_idxs,
                                                              shuffle=False)
            self.history = {'loss': [], 'val_loss': []}
        else:
            self.history = {'loss': []}

        n_points = len(triplet_idxs)
        sess = self.session
        if self.v1_compat_mode:
            sess.run(tf.compat.v1.global_variables_initializer())
        else:
            sess.run(tf.global_variables_initializer())

        n_batches = int(np.ceil(n_points / self.batch_size))
        bs = self.batch_size

        best_epoch = -1
        min_err = np.inf

        for e in range(self.n_epochs):
            if self.early_stopping and best_epoch > 0 and e > best_epoch + self.patience:
                exited_early_stopping = True
                break

            triplet_idxs = shuffle(triplet_idxs,
                                   random_state=self.random_state + e)

            loss_value = []
            n_skipped = 0
            for i in range(n_batches):
                if (i % 1000) == 0:
                    if verbose:
                        logging.info("Epoch: {} \t step: {}/{} batches".format(
                            e, i, n_batches))

                triplets_idxs_batch = triplet_idxs[i * bs:(i + 1) * bs, :]
                xa_batch, xp_batch, xn_batch = fetch_method(
                    triplets_idxs_batch, lods)
                batch_loss_value = self.eval_var(self.loss, xa_batch, xp_batch,
                                                 xn_batch)
                if np.isfinite(batch_loss_value):
                    self.train_step.run(feed_dict=self.get_fd(
                        xa_batch, xp_batch, xn_batch),
                                        session=self.session)
                    loss_value.append(batch_loss_value)
                else:
                    n_skipped += 1
            loss_value = np.mean(loss_value)
            self.history['loss'].append(loss_value)
            if not np.isfinite(loss_value):
                logging.warning("Training stopped: nan or inf loss value")
                break
            if not self.early_stopping and verbose:
                logging.info("===> Epoch: {} \t loss: {:.6f}".format(
                    e, loss_value))
            else:
                n_val_batches = int(
                    np.ceil(len(triplet_idxs_val) / self.batch_size))
                val_loss_value = []
                for i in range(n_val_batches):
                    triplets_idxs_batch_val = triplet_idxs_val[i * bs:(i + 1) *
                                                               bs, :]
                    xa_batch_val, xp_batch_val, xn_batch_val = fetch_method(
                        triplets_idxs_batch_val, lods)
                    batch_loss_value = self.eval_var(self.loss, xa_batch_val,
                                                     xp_batch_val,
                                                     xn_batch_val)
                    if np.isfinite(batch_loss_value):
                        val_loss_value.append(batch_loss_value)
                val_loss_value = np.nanmean(val_loss_value)

                self.history['val_loss'].append(val_loss_value)
                if not np.isfinite(val_loss_value):
                    logging.warning(
                        "Training stopped: nan or inf validation loss value")
                    break
                if val_loss_value < min_err:
                    min_err = val_loss_value
                    best_epoch = e
                    self.best_epoch_ = e
                    self.saver.save_weights(self.session)
                    if verbose:
                        logging.info(
                            "===> Epoch: {} \t loss: {:.6f} \t val_loss: {:.6f} ** (new best epoch)"
                            .format(e, loss_value, val_loss_value))
                else:
                    if verbose:
                        logging.info(
                            "===> Epoch: {} \t loss: {:.6f} \t val_loss: {:.6f}"
                            .format(e, loss_value, val_loss_value))

        if self.early_stopping:
            self.saver.restore_weights(self.session)
        else:
            self.saver.save_weights(self.session)

        tend = time.time()
        fitting_time = tend - t0
        self.last_fit_duration_ = fitting_time

        if log_time:
            logging.info(
                "[Triplet fitting time]: {} minutes and {} seconds".format(
                    fitting_time // 60, int(fitting_time % 60)))
        return self.history

    def fit(self, X_a, X_p, X_n, log_time=False):
        assert len(X_a) == len(X_p)
        assert len(X_p) == len(X_n)

        t0 = time.time()

        if self.early_stopping:
            X_a, X_a_val, X_p, X_p_val, X_n, X_n_val = train_test_split(
                X_a, X_p, X_n, shuffle=False)
            self.history = {'loss': [], 'val_loss': []}
        else:
            self.history = {'loss': []}

        n_points = len(X_a)
        sess = self.session
        if self.v1_compat_mode:
            sess.run(tf.compat.v1.global_variables_initializer())
        else:
            sess.run(tf.global_variables_initializer())
        self.history['loss'].append(self.eval_var(self.loss, X_a, X_p, X_n))
        if self.early_stopping:
            self.history['val_loss'].append(
                self.eval_var(self.loss, X_a_val, X_p_val, X_n_val))
        n_batches = int(np.ceil(n_points / self.batch_size))
        bs = self.batch_size

        best_epoch = -1
        min_err = np.inf

        logging.info("Initial loss(es): {}".format(self.history))
        for e in range(self.n_epochs):
            if self.early_stopping and best_epoch > 0 and e > best_epoch + self.patience:
                exited_early_stopping = True
                break

            X_a, X_p, X_n = shuffle(X_a,
                                    X_p,
                                    X_n,
                                    random_state=self.random_state + e)
            for i in range(n_batches):
                if (i % 1000) == 0:
                    logging.info("Epoch: {} \t step: {}/{} batches".format(
                        e, i, n_batches))
                xa_batch = X_a[i * bs:(i + 1) * bs, :]
                xp_batch = X_p[i * bs:(i + 1) * bs, :]
                xn_batch = X_n[i * bs:(i + 1) * bs, :]
                self.train_step.run(feed_dict=self.get_fd(
                    xa_batch, xp_batch, xn_batch),
                                    session=self.session)
            loss_value = self.eval_var(self.loss, X_a, X_p, X_n)
            self.history['loss'].append(loss_value)
            if not self.early_stopping:
                logging.info("===> Epoch: {} \t loss: {:.3f}".format(
                    e, loss_value))
            else:
                val_loss_value = self.eval_var(self.loss, X_a_val, X_p_val,
                                               X_n_val)
                self.history['val_loss'].append(val_loss_value)
                if val_loss_value < min_err:
                    min_err = val_loss_value
                    best_epoch = e
                    self.best_epoch_ = e
                    self.saver.save_weights(self.session)
                    logging.info(
                        "===> Epoch: {} \t loss: {:.3f} \t val_loss: {:.3f} ** (new best epoch)"
                        .format(e, loss_value, val_loss_value))
                else:
                    logging.info(
                        "===> Epoch: {} \t loss: {:.3f} \t val_loss: {:.3f}".
                        format(e, loss_value, val_loss_value))

        if self.early_stopping:
            self.saver.restore_weights(self.session)
        else:
            self.saver.save_weights(self.session)

        tend = time.time()
        fitting_time = tend - t0
        self.last_fit_duration_ = fitting_time

        if log_time:
            logging.info(
                "[TripletEncoder fitting time]: {} minutes and {} seconds".
                format(fitting_time // 60, int(fitting_time % 60)))
        return self.history

    def transform(self, X):
        output_var = self.forward_pass(self.placeholders['anchor'])
        output = output_var.eval(feed_dict={self.placeholders['anchor']: X},
                                 session=self.session)
        return output

    def persist(self, fpath):
        data = self.get_persist_info()
        if os.path.dirname(fpath) != "":
            if not os.path.exists(os.path.dirname(fpath)):
                os.makedirs(os.path.dirname(fpath))
        np.save(fpath, data)

    def serialize(self, fpath):
        self.persist(fpath)

    def get_persist_info(self):
        signature_data = {
            'input_dim': self.input_dim,
            'layer_sizes': self.layer_sizes,
            'activations': self.activations,
            'alpha': self.alpha,
            'learning_rate': self.learning_rate,
            'batch_size': self.batch_size,
            'n_epochs': self.n_epochs,
            'early_stopping': self.early_stopping,
            'patience': self.patience,
            'random_state': self.random_state,
            'v1_compat_mode': self.v1_compat_mode
        }
        other_data = {
            'best_weights': self.saver.best_params,  # ws and bs
            'history': self.history,
            'best_epoch': self.best_epoch_,
            'last_fit_duration': self.last_fit_duration_,
            'centroids': self.centroids,
            'altered_centroids': self.altered_centroids
        }
        return {'signature': signature_data, 'other': other_data}

    def clone(self):
        data = self.get_persist_info()
        return TEncoder.make_instance(data['signature'], data['other'])

    @staticmethod
    def make_instance(signature_data, other_data):
        instance = TEncoder(**signature_data)
        instance.compile()
        instance.saver.best_params = other_data['best_weights'].copy()
        instance.saver.restore_weights(instance.session)
        instance.history = other_data['history'].copy()
        instance.last_fit_duration_ = other_data['last_fit_duration']
        instance.best_epoch_ = other_data['best_epoch']
        if 'centroids' in other_data:
            instance.centroids = other_data['centroids']
            instance.altered_centroids = other_data['altered_centroids']
        return instance

    @staticmethod
    def load_from_file(fpath):
        data = np.load(fpath, allow_pickle=True)[()]
        return TEncoder.make_instance(data['signature'], data['other'])
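
A minimal, hypothetical usage sketch for TEncoder (synthetic data; the module-level helpers weight_variable, bias_variable, Saver, shuffle and train_test_split that the class references are assumed to be importable alongside it):

import numpy as np

rng = np.random.RandomState(0)
X_a = rng.rand(1000, 20).astype(np.float32)                  # anchors
X_p = X_a + 0.01 * rng.randn(1000, 20).astype(np.float32)    # positives close to anchors
X_n = rng.rand(1000, 20).astype(np.float32)                  # negatives

enc = TEncoder(input_dim=20,
               layer_sizes=[16, 8],
               activations=['relu', None],
               v1_compat_mode=True)   # use tf.compat.v1 on a TF2 runtime
enc.compile()
history = enc.fit(X_a, X_p, X_n, log_time=True)
embeddings = enc.transform(X_a)       # shape (1000, 8)
enc.persist("models/tencoder.npy")
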
Example #2
class CAEPlus:
    def __init__(self,
                 input_dim,
                 layer_sizes,
                 activations,
                 lamda=1e-1,
                 learning_rate=0.001,
                 batch_size=128,
                 n_epochs=100,
                 config_vec_size=10,
                 gamma=1e-1,
                 early_stopping=True,
                 patience=10,
                 random_state=42):
        """
        Implements a contractive auto-encoder

        layer_sizes: list
            It concerns only the encoder part and not the decoder part.
        lamda: term that multiplies the jacobian term of the loss function.
        gamma: term that multiplies the configuration approximation term of the
               loss function
        """
        self.input_dim = input_dim
        self.layer_sizes = layer_sizes
        self.activations = activations
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.early_stopping = early_stopping
        self.patience = patience
        self.random_state = random_state
        self.session = tf.compat.v1.Session()
        self.best_epoch_ = None
        self.last_fit_duration_ = None
        self.centroids = None
        self.altered_centroids = None
        self.lamda = lamda
        self.config_vec_size = config_vec_size  # size of the configuration vector
        self.gamma = gamma  # multiplier of the configuration approx

        assert len(set(activations)) == 1
        assert len(layer_sizes) <= 2
        assert len(activations) == len(layer_sizes)

    def get_jacobian_loss(self, iv_encodings):
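        # Contractive penalty: closed-form Frobenius norm of the encoder
        # Jacobian (cf. Rifai et al., 2011). For a single sigmoid layer
        # h = sigmoid(xW + b), ||J||_F^2 = sum_j h_j^2 (1 - h_j)^2 * sum_i W_ij^2;
        # the relu and two-layer branches below use the corresponding indicator
        # derivatives. Columns of W that feed the configuration part of the
        # encoding are excluded, so only the invariant dimensions are contracted.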
        encodings = iv_encodings
        if len(self.layer_sizes) == 1 and self.activations[0] == 'sigmoid':
            w = self.weights[0]
            w = w[:, :-self.config_vec_size]
            w_sum_over_input_dim = tf.reduce_sum(tf.square(w), axis=0)
            w_ = tf.expand_dims(w_sum_over_input_dim, 1)
            h_ = tf.square(encodings * (1 - encodings))
            h_times_w_ = tf.matmul(h_, w_)
            jacobian = tf.reduce_mean(h_times_w_)

        elif len(self.layer_sizes) == 1 and self.activations[0] == 'relu':
            w = self.weights[0]
            b = self.biases[0]
            w = w[:, :-self.config_vec_size]
            b = b[:-self.config_vec_size]
            pre_activation = tf.matmul(self.input_pl, w) + b
            indicator = tf.nn.relu(tf.sign(pre_activation))
            w_s = tf.square(w)
            w_ = tf.transpose(tf.reduce_sum(w_s, axis=0, keepdims=True))
            batch_jacobian_vec = tf.matmul(indicator, w_)
            jacobian = tf.reduce_mean(batch_jacobian_vec)

        elif len(self.layer_sizes) == 2 and self.activations[0] == 'sigmoid':
            w1_var = self.weights[0]
            w2_var = self.weights[1]
            w2_var = w2_var[:, :-self.config_vec_size]
            b1_var = self.biases[0]
            x_pl = self.input_pl
            intermediate = tf.nn.sigmoid(tf.matmul(x_pl, w1_var) + b1_var)
            z_ = intermediate * (1 - intermediate)
            aux = tf.expand_dims(z_, 2) * w2_var
            k_sum = tf.matmul(w1_var, aux)
            k_ss = tf.square(k_sum)
            sum_k_ss = tf.reduce_sum(k_ss, axis=1)
            h_ = tf.square(encodings * (1 - encodings))
            batch_jacobian_vec = tf.reduce_sum(h_ * sum_k_ss, axis=1)
            jacobian = tf.reduce_mean(batch_jacobian_vec)
        elif len(self.layer_sizes) == 2 and self.activations[0] == 'relu':
            x_pl = self.input_pl
            w1_var = self.weights[0]
            w2_var = self.weights[1]
            w2_var = w2_var[:, :-self.config_vec_size]
            b1_var = self.biases[0]
            b2_var = self.biases[1]
            b2_var = b2_var[:-self.config_vec_size]
            preac_1 = tf.matmul(x_pl, w1_var) + b1_var
            intermediate = tf.nn.relu(preac_1)
            indicator_1 = tf.nn.relu(tf.sign(preac_1))
            preac_2 = tf.matmul(intermediate, w2_var) + b2_var
            indicator_2 = tf.nn.relu(tf.sign(preac_2))
            z_ = indicator_1
            aux = tf.expand_dims(z_, 2) * w2_var
            k_sum = tf.matmul(w1_var, aux)
            k_ss = tf.square(k_sum)
            sum_k_ss = tf.reduce_sum(k_ss, axis=1)
            h_ = indicator_2
            batch_jacobian_vec = tf.reduce_sum(h_ * sum_k_ss, axis=1)
            jacobian = tf.reduce_mean(batch_jacobian_vec)
        else:
            raise NotImplementedError(
                "Jacobian not yet implemented for this activation: {}".format(
                    self.activations[0]))
        return jacobian

    def compile(self):
        self._init_architecture()
        obs_approx = self.full_forward_pass(self.input_pl)
        encodings = self.forward_pass(self.input_pl)

        iv_encodings = encodings[:, :-self.config_vec_size]
        v_encodings = encodings[:, -self.config_vec_size:]

        recons_loss = tf.reduce_mean(tf.square(obs_approx - self.input_pl))
        jacobian_loss = self.get_jacobian_loss(iv_encodings)
        config_approx_loss = tf.reduce_mean(
            tf.square(v_encodings - self.config_pl))

        loss = recons_loss

        if self.lamda > 1e-9:
            loss += self.lamda * jacobian_loss
        if self.gamma > 1e-9:
            loss += self.gamma * config_approx_loss

        train_step = tf.compat.v1.train.AdamOptimizer(
            self.learning_rate).minimize(loss)

        self.loss = loss
        self.recons_loss = recons_loss
        self.jacobian_loss = jacobian_loss
        self.config_approx_loss = config_approx_loss
        self.train_step = train_step

    def forward_pass(self, input_pl):
        output = input_pl
        for i in range(len(self.weights)):
            output = tf.matmul(output, self.weights[i]) + self.biases[i]
            if self.activations[i] == 'relu':
                output = tf.nn.relu(output)
            elif self.activations[i] == 'sigmoid':
                output = tf.nn.sigmoid(output)
            elif self.activations[i] == '' or self.activations[i] is None:
                pass
            else:
                raise NotImplementedError(
                    "This activation ({}) is not yet implemented.".format(
                        self.activations[i]))
        return output

    def full_forward_pass(self, input_pl):
        encoding = self.forward_pass(input_pl)
        output = encoding
        for i in range(len(self.decoder_weights)):
            output = tf.matmul(
                output, self.decoder_weights[i]) + self.decoder_biases[i]
            if self.activations[len(self.activations) - i - 1] == 'relu':
                output = tf.nn.relu(output)
        return output

    def _init_architecture(self):
        tf.compat.v1.disable_eager_execution()
        self.input_pl = tf.compat.v1.placeholder(tf.float32,
                                                 shape=(None, self.input_dim))
        self.config_pl = tf.compat.v1.placeholder(tf.float32,
                                                  shape=(None,
                                                         self.config_vec_size))

        weights = []
        biases = []
        i_dim = self.input_dim

        for layer_size in self.layer_sizes:
            w = weight_variable([i_dim, layer_size])
            b = bias_variable([layer_size])
            i_dim = layer_size
            weights.append(w)
            biases.append(b)

        decoder_weights = []
        decoder_biases = []
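        # Tied weights: the decoder reuses the transposed encoder weights and
        # only learns its own biases (hence the Saver below tracks
        # decoder_biases but not decoder_weights).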
        for w in weights[::-1]:
            decoder_weights.append(tf.transpose(w))
            decoder_biases.append(bias_variable([int(w.shape[0])]))

        self.weights = weights
        self.biases = biases
        self.decoder_weights = decoder_weights
        self.decoder_biases = decoder_biases

        self.saver = Saver(self.weights + self.biases + self.decoder_biases)

    def get_fd(self, X, config=None):
        if config is None:
            return {
                self.input_pl: X,
            }
        else:
            return {self.input_pl: X, self.config_pl: config}

    def eval_var(self, var, X, config=None):
        return var.eval(feed_dict=self.get_fd(X, config=config),
                        session=self.session)

    def log_losses(self, X, config, val=False, verbose=False, e=0):
        recons_loss = self.eval_var(self.recons_loss, X)
        jacobian_loss = self.eval_var(self.jacobian_loss, X)
        config_approx_loss = self.eval_var(self.config_approx_loss, X, config)
        loss = recons_loss
        if self.lamda > 1e-9:
            loss += self.lamda * jacobian_loss
        if self.gamma > 1e-9:
            loss += self.gamma * config_approx_loss

        if not val:
            self.history['loss'].append(loss)
            self.history['recons_loss'].append(recons_loss)
            self.history['jacobian_loss'].append(jacobian_loss)
            self.history['config_approx_loss'].append(config_approx_loss)
        else:
            self.history['val_loss'].append(loss)
            self.history['val_recons_loss'].append(recons_loss)
            self.history['val_jacobian_loss'].append(jacobian_loss)
            self.history['val_config_approx_loss'].append(config_approx_loss)

        if verbose:
            if val:
                prefix = "[VAL]"
            else:
                prefix = "[TRAIN]"
            logging.info(
                "{} Epoch {} - Losses: recons: {:.5f} \t jacobi: {:.5f} \config_approx: {:.5f} \t total: {:.5f}"
                .format(prefix, e, recons_loss, jacobian_loss,
                        config_approx_loss, loss))
        return loss

    def fit(self, X, config, log_time=False, verbose=False):
        t0 = time.time()

        self.history = {
            'loss': [],
            'recons_loss': [],
            'jacobian_loss': [],
            'config_approx_loss': []
        }
        if self.early_stopping:
            X, X_val, config, config_val = train_test_split(X,
                                                            config,
                                                            shuffle=False)
            self.history['val_loss'] = []
            self.history['val_recons_loss'] = []
            self.history['val_jacobian_loss'] = []
            self.history['val_config_approx_loss'] = []

        n_points = len(X)
        sess = self.session
        sess.run(tf.compat.v1.global_variables_initializer())
        self.log_losses(X, config, verbose=verbose)

        if self.early_stopping:
            self.log_losses(X_val, config_val, val=True, verbose=verbose)

        n_batches = int(np.ceil(n_points / self.batch_size))
        bs = self.batch_size

        best_epoch = -1
        min_err = np.inf

        logging.info("Initial loss(es): {}".format(self.history))
        for e in range(self.n_epochs):
            if self.early_stopping and best_epoch > 0 and e > best_epoch + self.patience:
                exited_early_stopping = True
                break

            X, config = shuffle(X, config, random_state=self.random_state + e)
            for i in range(n_batches):
                x_batch = X[i * bs:(i + 1) * bs, :]
                conf_batch = config[i * bs:(i + 1) * bs, :]
                self.train_step.run(feed_dict=self.get_fd(x_batch,
                                                          config=conf_batch),
                                    session=self.session)
            loss_value = self.log_losses(X, config, verbose=verbose, e=e)
            if self.early_stopping:
                val_loss_value = self.log_losses(X_val,
                                                 config_val,
                                                 val=True,
                                                 verbose=verbose,
                                                 e=e)
                if val_loss_value < min_err:
                    min_err = val_loss_value
                    best_epoch = e
                    self.best_epoch_ = e
                    self.saver.save_weights(self.session)
                    if verbose:
                        logging.info(
                            "===> Epoch: {} \t loss: {:.6f} \t val_loss: {:.6f} ** (new best epoch)"
                            .format(e, loss_value, val_loss_value))

        if self.early_stopping:
            self.saver.restore_weights(self.session)
        else:
            self.saver.save_weights(self.session)

        tend = time.time()
        fitting_time = tend - t0
        self.last_fit_duration_ = fitting_time

        if log_time:
            logging.info(
                "[autoencoder fitting time]: {} minutes and {} seconds".format(
                    fitting_time // 60, int(fitting_time % 60)))
        return self.history

    def transform(self, X, keep_config_dimensions=False):
        output_var = self.forward_pass(self.input_pl)
        output = output_var.eval(feed_dict={self.input_pl: X},
                                 session=self.session)
        if keep_config_dimensions:
            return output

        return output[:, :-self.config_vec_size]

    def persist(self, fpath):
        data = self.get_persist_info()
        if os.path.dirname(fpath) != "":
            if not os.path.exists(os.path.dirname(fpath)):
                os.makedirs(os.path.dirname(fpath))
        np.save(fpath, data)

    def serialize(self, fpath):
        self.persist(fpath)

    def get_persist_info(self):
        signature_data = {
            'input_dim': self.input_dim,
            'layer_sizes': self.layer_sizes,
            'activations': self.activations,
            'learning_rate': self.learning_rate,
            'batch_size': self.batch_size,
            'n_epochs': self.n_epochs,
            'early_stopping': self.early_stopping,
            'patience': self.patience,
            'random_state': self.random_state,
            'lamda': self.lamda
        }
        other_data = {
            'best_weights': self.saver.best_params,  # ws and bs
            'history': self.history,
            'best_epoch': self.best_epoch_,
            'last_fit_duration': self.last_fit_duration_,
            'centroids': self.centroids,
            'altered_centroids': self.altered_centroids
        }
        return {'signature': signature_data, 'other': other_data}

    def clone(self):
        data = self.get_persist_info()
        return CAEPlus.make_instance(data['signature'], data['other'])

    @staticmethod
    def make_instance(signature_data, other_data):
        instance = CAEPlus(**signature_data)
        instance.compile()
        instance.saver.best_params = other_data['best_weights'].copy()
        instance.saver.restore_weights(instance.session)
        instance.history = other_data['history'].copy()
        instance.last_fit_duration_ = other_data['last_fit_duration']
        instance.best_epoch_ = other_data['best_epoch']
        if 'centroids' in other_data:
            instance.centroids = other_data['centroids']
            instance.altered_centroids = other_data['altered_centroids']
        return instance

    @staticmethod
    def load_from_file(fpath):
        data = np.load(fpath, allow_pickle=True)[()]
        return CAEPlus.make_instance(data['signature'], data['other'])

    @staticmethod
    def build(input_dim=561,
              encoding_dim=5,
              depth=2,
              nh=20,
              activation='sigmoid',
              learning_rate=1e-3,
              batch_size=32,
              n_epochs=500,
              random_state=10,
              early_stopping=False,
              patience=10,
              lamda=1e-1,
              gamma=1e-1,
              config_vec_size=10):
        """
        Provides another interface (other than the constructor) for
        constructing autoencoder objects...
        """
        encoder_hidden_layers = [int(nh / (2**i)) for i in range(depth - 1)]
        if len(encoder_hidden_layers) > 0:
            if 0 in encoder_hidden_layers or encoder_hidden_layers[
                    -1] < encoding_dim:
                return None
        hidden_layer_sizes = encoder_hidden_layers + [encoding_dim]
        activations = [activation] * depth

        ae = CAEPlus(input_dim,
                     hidden_layer_sizes,
                     activations,
                     lamda=lamda,
                     gamma=gamma,
                     learning_rate=learning_rate,
                     batch_size=batch_size,
                     n_epochs=n_epochs,
                     early_stopping=early_stopping,
                     patience=patience,
                     random_state=random_state,
                     config_vec_size=config_vec_size)
        return ae

    @staticmethod
    def valid_params(ae_params, encoding_size):
        nh = ae_params['nh']
        depth = ae_params['depth']
        if depth >= 2:
            return (nh / (2**(depth - 2))) > encoding_size
        return True
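
A hedged usage sketch for CAEPlus through its build() factory (synthetic data; encoding_dim is chosen larger than config_vec_size so the invariant part of the encoding is non-empty; the helpers weight_variable, bias_variable and Saver are again assumed to be available):

import numpy as np

rng = np.random.RandomState(0)
X = rng.rand(500, 561).astype(np.float32)        # observations
config = rng.rand(500, 10).astype(np.float32)    # knob / configuration vectors

cae = CAEPlus.build(input_dim=561,
                    encoding_dim=15,             # 10 config dims + 5 invariant dims
                    depth=2,
                    nh=20,
                    activation='sigmoid',
                    n_epochs=50,
                    early_stopping=True,
                    config_vec_size=10)
cae.compile()
history = cae.fit(X, config, log_time=True, verbose=True)
Z = cae.transform(X)                             # invariant dims only: (500, 5)
Z_full = cae.transform(X, keep_config_dimensions=True)
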
Example #3
class TripletPlusPlus:
    def __init__(self, input_dim, layer_sizes, activations, alpha=0.01,
                 gamma=1e-1, lamda=1e-1, config_vec_size=10,
                 learning_rate=0.001, batch_size=128, n_epochs=40,
                 early_stopping=True, patience=5, random_state=42,
                 epsilon=1e-5, v1_compat_mode=False,
                 weight_init='random'):
        """
        Siamese Neural network with triplet-loss implemented and featuring
        a reconstruction term for (1) anchor, (2) positive (3) negative as well
        as a configuration approximation term for (1) anchor, (2) positive and (3)
        negative.


        The layer_sizes is only defined for the encoder part and not
        the decoder part...

        gamma: coefficient that multiplies the sum of the 3 reconstruction terms.
        lamda: coefficient that multiplies the sum of the 3 config approx terms.
        alpha: margin used within the triplet loss...
        config_vec_size: int
            size of the configuration vector (or in other words, the number of knobs)

        epsilon: float
            Epsilon used in Adam optimizer for numerical stability

        """
        self.input_dim = input_dim
        self.layer_sizes = layer_sizes
        self.activations = activations
        self.alpha = alpha
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.early_stopping = early_stopping
        self.patience = patience
        self.random_state = random_state
        if v1_compat_mode:
            self.session = tf.compat.v1.Session()
        else:
            self.session = tf.Session()
        self.best_epoch_ = None
        self.last_fit_duration_ = None
        self.centroids = None
        self.altered_centroids = None
        self.gamma = gamma
        self.lamda = lamda
        self.config_vec_size = config_vec_size
        self.weight_init = weight_init
        self.epsilon = epsilon
        self.v1_compat_mode = v1_compat_mode

    def compile(self):
        if self.v1_compat_mode:
            tf.compat.v1.disable_eager_execution()
        self._init_architecture()
        anchor_encoding = self.forward_pass(self.placeholders['anchor'])
        pos_encoding = self.forward_pass(self.placeholders['pos'])
        neg_encoding = self.forward_pass(self.placeholders['neg'])

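        # Each encoding is split into an "invariant" part (fed to the triplet
        # loss) and a configuration part (trained to approximate the knob
        # vector of the corresponding input).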
        anchor_encoding_iv = anchor_encoding[:, :-self.config_vec_size]
        anchor_encoding_v = anchor_encoding[:, -self.config_vec_size:]

        pos_encoding_iv = pos_encoding[:, :-self.config_vec_size]
        pos_encoding_v = pos_encoding[:, -self.config_vec_size:]

        neg_encoding_iv = neg_encoding[:, :-self.config_vec_size]
        neg_encoding_v = neg_encoding[:, -self.config_vec_size:]

        ap_norm = tf.norm(
            tf.square(anchor_encoding_iv - pos_encoding_iv),
            keepdims=True, axis=1)
        an_norm = tf.norm(
            tf.square(anchor_encoding_iv - neg_encoding_iv),
            keepdims=True, axis=1)
        triplet_loss = tf.nn.relu(ap_norm - an_norm + self.alpha)
        triplet_loss = tf.reduce_mean(triplet_loss)

        anchor_output = self.full_forward_pass(self.placeholders['anchor'])
        pos_output = self.full_forward_pass(self.placeholders['pos'])
        neg_output = self.full_forward_pass(self.placeholders['neg'])
        anchor_recons_loss = tf.reduce_mean(
            tf.square(anchor_output - self.placeholders['anchor']))
        pos_recons_loss = tf.reduce_mean(
            tf.square(pos_output - self.placeholders['pos']))
        neg_recons_loss = tf.reduce_mean(
            tf.square(neg_output - self.placeholders['neg']))

        recons_loss = anchor_recons_loss + pos_recons_loss + neg_recons_loss

        anchor_config_loss = tf.reduce_mean(
            tf.square(anchor_encoding_v - self.placeholders['config_anchor']))

        pos_config_loss = tf.reduce_mean(
            tf.square(pos_encoding_v - self.placeholders['config_pos']))

        neg_config_loss = tf.reduce_mean(
            tf.square(neg_encoding_v - self.placeholders['config_neg']))

        config_loss = anchor_config_loss + pos_config_loss + neg_config_loss

        loss = triplet_loss
        if self.gamma > 1e-9:
            loss += self.gamma * recons_loss
        if self.lamda > 1e-9:
            loss += self.lamda * config_loss

        if self.v1_compat_mode:
            adam = tf.compat.v1.train.AdamOptimizer
        else:
            adam = tf.train.AdamOptimizer

        train_step = adam(self.learning_rate,
                          epsilon=self.epsilon).minimize(loss)

        self.loss = loss
        self.train_step = train_step

    def forward_pass(self, input_pl):
        output = input_pl
        for i in range(len(self.weights)):
            output = tf.matmul(output, self.weights[i]) + self.biases[i]
            if self.activations[i] == 'relu':
                output = tf.nn.relu(output)
            elif self.activations[i] == 'sigmoid':
                output = tf.nn.sigmoid(output)
            elif self.activations[i] == '' or self.activations[i] is None:
                pass
            else:
                raise NotImplementedError(
                    "This activation ({}) is not yet implemented.".format(
                        self.activations[i]))
        return output

    def full_forward_pass(self, input_pl):
        encoding = self.forward_pass(input_pl)
        output = encoding
        logging.debug("Full forward pass: "******"{}: prev_output_shape: {} \t weights shape: {} \t bias shape: {}".format(
                i, output.shape, self.decoder_weights[i].shape, self.decoder_biases[i].shape))
            output = tf.matmul(
                output, self.decoder_weights[i]) + self.decoder_biases[i]
            if self.activations[len(self.activations) - i - 1] == 'relu':
                output = tf.nn.relu(output)
        return output

    def _init_architecture(self):
        if self.v1_compat_mode:
            pl_fnc = tf.compat.v1.placeholder
        else:
            pl_fnc = tf.placeholder
        anchor_pl = pl_fnc(tf.float32, shape=(None, self.input_dim))
        pos_pl = pl_fnc(tf.float32, shape=(None, self.input_dim))
        neg_pl = pl_fnc(tf.float32, shape=(None, self.input_dim))
        config_anchor_pl = pl_fnc(
            tf.float32, shape=(None, self.config_vec_size))
        config_pos_pl = pl_fnc(
            tf.float32, shape=(None, self.config_vec_size))
        config_neg_pl = pl_fnc(
            tf.float32, shape=(None, self.config_vec_size))

        self.placeholders = {
            'anchor': anchor_pl,
            'pos': pos_pl,
            'neg': neg_pl,
            'config_anchor': config_anchor_pl,
            'config_pos': config_pos_pl,
            'config_neg': config_neg_pl
        }

        weights = []
        biases = []
        i_dim = self.input_dim
        # encoder
        for layer_size in self.layer_sizes:
            if self.weight_init == "glorot":
                r = np.sqrt(6./(i_dim + layer_size))
                w = weight_variable2(
                    [i_dim, layer_size],
                    r, v1_compat_mode=self.v1_compat_mode)
                b = bias_variable2([layer_size])
            else:
                w = weight_variable(
                    [i_dim, layer_size],
                    v1_compat_mode=self.v1_compat_mode)
                b = bias_variable([layer_size])
            i_dim = layer_size
            weights.append(w)
            biases.append(b)

        decoder_weights = []
        decoder_biases = []
        for w in weights[::-1]:
            decoder_weights.append(tf.transpose(w))
            if self.weight_init == "glorot":
                decoder_biases.append(bias_variable2([int(w.shape[0])]))
            else:
                decoder_biases.append(bias_variable([int(w.shape[0])]))

        self.weights = weights
        self.biases = biases
        self.decoder_weights = decoder_weights
        self.decoder_biases = decoder_biases

        self.saver = Saver(self.weights + self.biases + self.decoder_biases)

    def get_fd(self, X_a, X_p, X_n, C_a, C_p, C_n):
        return {self.placeholders['anchor']: X_a,
                self.placeholders['pos']: X_p,
                self.placeholders['neg']: X_n,
                self.placeholders['config_anchor']: C_a,
                self.placeholders['config_pos']: C_p,
                self.placeholders['config_neg']: C_n
                }

    def eval_var(self, var, X_a, X_p, X_n, C_a, C_p, C_n):
        return var.eval(
            feed_dict=self.get_fd(X_a, X_p, X_n, C_a, C_p, C_n),
            session=self.session)

    def fit_idxs(
            self, triplet_idxs, fetch_method, lods, log_time=False,
            verbose=False):
        t0 = time.time()

        triplet_idxs = np.array(triplet_idxs)

        if self.early_stopping:
            triplet_idxs, triplet_idxs_val = train_test_split(
                triplet_idxs, shuffle=False)
            self.history = {'loss': [], 'val_loss': []}
        else:
            self.history = {'loss': []}

        n_points = len(triplet_idxs)
        sess = self.session
        if self.v1_compat_mode:
            sess.run(tf.compat.v1.global_variables_initializer())
        else:
            sess.run(tf.global_variables_initializer())

        n_batches = int(np.ceil(n_points / self.batch_size))
        bs = self.batch_size

        best_epoch = -1
        min_err = np.inf

        for e in range(self.n_epochs):
            if self.early_stopping and best_epoch > 0 and e > best_epoch + self.patience:
                exited_early_stopping = True
                break

            triplet_idxs = shuffle(
                triplet_idxs, random_state=self.random_state + e)

            loss_value = []
            n_skipped = 0
            for i in range(n_batches):
                if (i % 1000) == 0:
                    if verbose:
                        logging.info(
                            "Epoch: {} \t step: {}/{} batches".format(e, i, n_batches))

                triplets_idxs_batch = triplet_idxs[i*bs:(i+1)*bs, :]
                xa_batch, xp_batch, xn_batch, ca_batch, cp_batch, cn_batch = fetch_method(
                    triplets_idxs_batch, lods)
                batch_loss_value = self.eval_var(
                    self.loss, xa_batch, xp_batch, xn_batch, ca_batch,
                    cp_batch, cn_batch)
                if np.isfinite(batch_loss_value):
                    self.train_step.run(
                        feed_dict=self.get_fd(
                            xa_batch, xp_batch, xn_batch, ca_batch, cp_batch,
                            cn_batch),
                        session=self.session)
                    loss_value.append(batch_loss_value)
                else:
                    n_skipped += 1
            loss_value = np.mean(loss_value)
            self.history['loss'].append(loss_value)
            if not np.isfinite(loss_value):
                logging.warning("Training stopped: nan or inf loss value")
                break
            if not self.early_stopping and verbose:
                logging.info("===> Epoch: {} \t loss: {:.6f}".format(
                    e, loss_value))
            else:
                n_val_batches = int(
                    np.ceil(len(triplet_idxs_val) / self.batch_size))
                val_loss_value = []
                for i in range(n_val_batches):
                    triplets_idxs_batch_val = triplet_idxs_val[i*bs:(i+1)*bs, :]
                    xa_batch_val, xp_batch_val, xn_batch_val, ca_batch_val, cp_batch_val, cn_batch_val = fetch_method(
                        triplets_idxs_batch_val, lods)
                    batch_loss_value = self.eval_var(
                        self.loss, xa_batch_val, xp_batch_val, xn_batch_val,
                        ca_batch_val, cp_batch_val, cn_batch_val)
                    if np.isfinite(batch_loss_value):
                        val_loss_value.append(batch_loss_value)
                val_loss_value = np.nanmean(val_loss_value)

                self.history['val_loss'].append(val_loss_value)
                if not np.isfinite(val_loss_value):
                    logging.warning(
                        "Training stopped: nan or inf validation loss value")
                    break
                if val_loss_value < min_err:
                    min_err = val_loss_value
                    best_epoch = e
                    self.best_epoch_ = e
                    self.saver.save_weights(self.session)
                    if verbose:
                        logging.info("===> Epoch: {} \t loss: {:.6f} \t val_loss: {:.6f} ** (new best epoch)".format(
                            e, loss_value, val_loss_value))
                else:
                    if verbose:
                        logging.info("===> Epoch: {} \t loss: {:.6f} \t val_loss: {:.6f}".format(
                            e, loss_value, val_loss_value))

        if self.early_stopping:
            self.saver.restore_weights(self.session)
        else:
            self.saver.save_weights(self.session)

        tend = time.time()
        fitting_time = tend - t0
        self.last_fit_duration_ = fitting_time

        if log_time:
            logging.info(
                "[Triplet++ fitting time]: {} minutes and {} seconds".format(
                    fitting_time // 60,
                    int(fitting_time % 60)))
        return self.history

    def fit(self, X_a, X_p, X_n, config_a, config_p, config_n, log_time=False):
        assert len(X_a) == len(X_p)
        assert len(X_p) == len(X_n)
        assert len(config_a) == len(config_p)
        assert len(config_a) == len(config_n)
        assert len(config_a) == len(X_a)

        t0 = time.time()

        if self.early_stopping:
            X_a, X_a_val, X_p, X_p_val, X_n, X_n_val, config_a, config_a_val,\
                config_p, config_p_val, config_n, config_n_val = train_test_split(
                    X_a, X_p, X_n, config_a, config_p, config_n, shuffle=False)
            self.history = {'loss': [], 'val_loss': []}
        else:
            self.history = {'loss': []}

        n_points = len(X_a)
        sess = self.session
        if self.v1_compat_mode:
            sess.run(tf.compat.v1.global_variables_initializer())
        else:
            sess.run(tf.global_variables_initializer())
        self.history['loss'].append(self.eval_var(
            self.loss, X_a, X_p, X_n, config_a, config_p, config_n))
        if self.early_stopping:
            self.history['val_loss'].append(
                self.eval_var(
                    self.loss, X_a_val, X_p_val, X_n_val, config_a_val,
                    config_p_val, config_n_val))
        n_batches = int(np.ceil(n_points / self.batch_size))
        bs = self.batch_size

        best_epoch = -1
        min_err = np.inf

        logging.info("Initial loss(es): {}".format(self.history))
        for e in range(self.n_epochs):
            if self.early_stopping and best_epoch > 0 and e > best_epoch + self.patience:
                exited_early_stopping = True
                break

            X_a, X_p, X_n, C_a, C_p, C_n = shuffle(
                X_a, X_p, X_n, config_a, config_p, config_n,
                random_state=self.random_state + e)
            for i in range(n_batches):
                if (i % 1000) == 0:
                    logging.info(
                        "Epoch: {} \t step: {}/{} batches".format(e, i, n_batches))
                xa_batch = X_a[i*bs:(i+1)*bs, :]
                xp_batch = X_p[i*bs:(i+1)*bs, :]
                xn_batch = X_n[i*bs:(i+1)*bs, :]
                ca_batch = C_a[i*bs:(i+1)*bs, :]
                cp_batch = C_p[i*bs:(i+1)*bs, :]
                cn_batch = C_n[i*bs:(i+1)*bs, :]
                self.train_step.run(feed_dict=self.get_fd(
                    xa_batch, xp_batch, xn_batch, ca_batch, cp_batch, cn_batch),
                    session=self.session)
            loss_value = self.eval_var(self.loss, X_a, X_p, X_n, C_a, C_p, C_n)
            if not np.isfinite(loss_value):
                logging.warning("Training stopped: nan or inf loss value")
                break
            self.history['loss'].append(loss_value)
            if not self.early_stopping:
                logging.info("===> Epoch: {} \t loss: {:.6f}".format(
                    e, loss_value))
            else:
                val_loss_value = self.eval_var(
                    self.loss, X_a_val, X_p_val, X_n_val, config_a_val,
                    config_p_val, config_n_val)
                self.history['val_loss'].append(val_loss_value)
                if not np.isfinite(val_loss_value):
                    logging.warning(
                        "Training stopped: nan or inf validation loss value")
                    break
                if val_loss_value < min_err:
                    min_err = val_loss_value
                    best_epoch = e
                    self.best_epoch_ = e
                    self.saver.save_weights(self.session)
                    logging.info("===> Epoch: {} \t loss: {:.6f} \t val_loss: {:.6f} ** (new best epoch)".format(
                        e, loss_value, val_loss_value))
                else:
                    logging.info("===> Epoch: {} \t loss: {:.6f} \t val_loss: {:.6f}".format(
                        e, loss_value, val_loss_value))

        if self.early_stopping:
            self.saver.restore_weights(self.session)
        else:
            self.saver.save_weights(self.session)

        tend = time.time()
        fitting_time = tend - t0
        self.last_fit_duration_ = fitting_time

        if log_time:
            logging.info(
                "[Triplet++ fitting time]: {} minutes and {} seconds".format(
                    fitting_time // 60,
                    int(fitting_time % 60)))
        return self.history

    def transform(self, X, keep_config_dimensions=False):
        output_var = self.forward_pass(self.placeholders['anchor'])
        output = output_var.eval(feed_dict={
            self.placeholders['anchor']: X
        }, session=self.session)
        if keep_config_dimensions:
            return output
        else:
            return output[:, :-self.config_vec_size]

    def persist(self, fpath):
        data = self.get_persist_info()
        if os.path.dirname(fpath) != "":
            if not os.path.exists(os.path.dirname(fpath)):
                os.makedirs(os.path.dirname(fpath))
        np.save(fpath, data)

    def serialize(self, fpath):
        self.persist(fpath)

    def get_persist_info(self):
        signature_data = {
            'input_dim': self.input_dim,
            'layer_sizes': self.layer_sizes,
            'activations': self.activations,
            'alpha': self.alpha,
            'learning_rate': self.learning_rate,
            'batch_size': self.batch_size,
            'n_epochs': self.n_epochs,
            'early_stopping': self.early_stopping,
            'patience': self.patience,
            'random_state': self.random_state,
            'gamma': self.gamma,
            'lamda': self.lamda,
            'config_vec_size': self.config_vec_size,
            'epsilon': self.epsilon,
            'v1_compat_mode': self.v1_compat_mode
        }
        other_data = {
            'best_weights': self.saver.best_params,  # ws and bs
            'history': self.history,
            'best_epoch': self.best_epoch_,
            'last_fit_duration': self.last_fit_duration_,
            'centroids': self.centroids,
            'altered_centroids': self.altered_centroids
        }
        return {'signature': signature_data,
                'other': other_data}

    def clone(self):
        data = self.get_persist_info()
        return TripletPlusPlus.make_instance(data['signature'], data['other'])

    @staticmethod
    def make_instance(signature_data, other_data):
        instance = TripletPlusPlus(**signature_data)
        instance.compile()
        instance.saver.best_params = other_data['best_weights'].copy()
        instance.saver.restore_weights(instance.session)
        instance.history = other_data['history'].copy()
        instance.last_fit_duration_ = other_data['last_fit_duration']
        instance.best_epoch_ = other_data['best_epoch']
        if 'centroids' in other_data:
            instance.centroids = other_data['centroids']
            instance.altered_centroids = other_data['altered_centroids']
        return instance

    @staticmethod
    def load_from_file(fpath):
        data = np.load(fpath, allow_pickle=True)[()]
        return TripletPlusPlus.make_instance(data['signature'],
                                             data['other'])
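
A usage sketch for TripletPlusPlus along the same lines (synthetic triplets and knob vectors; the glorot branch assumes the module helpers weight_variable2 and bias_variable2 referenced in _init_architecture exist):

import numpy as np

rng = np.random.RandomState(0)
n, d, k = 800, 30, 10                        # triplets, feature dim, knobs
X_a, X_p, X_n = (rng.rand(n, d).astype(np.float32) for _ in range(3))
C_a, C_p, C_n = (rng.rand(n, k).astype(np.float32) for _ in range(3))

tpp = TripletPlusPlus(input_dim=d,
                      layer_sizes=[32, 16],  # encoding: 6 invariant + 10 config dims
                      activations=['relu', 'relu'],
                      config_vec_size=k,
                      weight_init='glorot',
                      v1_compat_mode=True)
tpp.compile()
history = tpp.fit(X_a, X_p, X_n, C_a, C_p, C_n, log_time=True)
Z = tpp.transform(X_a)                       # invariant part only: (n, 6)
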
Example #4
class SNNAE:
    def __init__(self,
                 input_dim,
                 layer_sizes,
                 activations,
                 lamda=1e-1,
                 T=2,
                 learning_rate=0.001,
                 batch_size=128,
                 n_epochs=100,
                 early_stopping=True,
                 patience=10,
                 random_state=42,
                 epsilon=1e-8):
        """
        Implements an auto-encoder with SNN loss

        layer_sizes: list
            It concerns only the encoder part and not the decoder part.
        lamda: term that multiplies the SNN term of the loss function.
        T: float, default=2
            Temperature term in the SNN loss

        epsilon: float, default=1e-8
            An Epsilon float number used for numerical stability
            (added to the log)

        """
        self.input_dim = input_dim
        self.layer_sizes = layer_sizes
        self.activations = activations
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.early_stopping = early_stopping
        self.patience = patience
        self.random_state = random_state
        self.session = tf.compat.v1.Session()
        self.best_epoch_ = None
        self.last_fit_duration_ = None
        self.centroids = None
        self.altered_centroids = None
        self.lamda = lamda
        self.epsilon = epsilon
        self.T = T

        assert len(activations) == len(layer_sizes)

    def get_log_at_i(self, x_batch, y_batch, i, selector):
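        # One term of the soft nearest neighbour (SNN) loss (cf. Frosst et al.,
        # 2019): for sample i, the log of the ratio between the similarity mass
        # exp(-||x_i - x_j||^2 / T) restricted to same-label neighbours (j != i)
        # and the mass over all other samples in the batch; `selector` zeroes
        # out the j == i term.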
        T = self.T
        b = self.batch_size
        ind_denom = selector
        denom = tf.reduce_sum(
            ind_denom *
            tf.exp(-tf.square(tf.norm(x_batch[i] - x_batch, axis=1)) / T))

        ind_num = tf.reshape(selector, [b, 1])
        indic_same_label = tf.cast(tf.equal(y_batch[i], y_batch),
                                   dtype=tf.float64)
        indic = tf.squeeze(tf.transpose(ind_num * indic_same_label))

        nume = tf.reduce_sum(
            indic *
            tf.exp(-tf.square(tf.norm(x_batch[i] - x_batch, axis=1)) / T))

        return tf.math.log(self.epsilon + nume / denom)

    def get_snn_loss(self, encoding_var, label_pl):
        elems = np.arange(self.batch_size)
        selectors = np.ones([self.batch_size, self.batch_size])
        for i in range(self.batch_size):
            selectors[i, i] = 0

        logs_batch = tf.map_fn(
            lambda t: self.get_log_at_i(encoding_var, label_pl, t[0], t[1]),
            (elems, selectors),
            dtype=tf.float64)
        snn = -tf.reduce_mean(logs_batch)

        return snn

    def compile(self):
        self._init_architecture()
        obs_approx = self.full_forward_pass(self.input_pl)
        encodings = self.forward_pass(self.input_pl)
        labels = self.label_pl

        recons_loss = tf.reduce_mean(tf.square(obs_approx - self.input_pl))
        snn_loss = self.get_snn_loss(encodings, labels)

        if self.lamda < 1e-9:
            loss = recons_loss
        else:
            loss = recons_loss + self.lamda * snn_loss

        train_step = tf.compat.v1.train.AdamOptimizer(
            self.learning_rate).minimize(loss)

        self.loss = loss
        self.recons_loss = recons_loss
        self.snn_loss = snn_loss
        self.train_step = train_step

    def forward_pass(self, input_pl):
        output = input_pl
        for i in range(len(self.weights)):
            output = tf.matmul(output, self.weights[i]) + self.biases[i]
            if self.activations[i] == 'relu':
                output = tf.nn.relu(output)
            elif self.activations[i] == 'sigmoid':
                output = tf.nn.sigmoid(output)
            elif self.activations[i] == '' or self.activations[i] is None:
                pass
            else:
                raise NotImplementedError(
                    "This activation ({}) is not yet implemented.".format(
                        self.activations[i]))
        return output

    def full_forward_pass(self, input_pl):
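        # Decode with tied weights: the decoder reuses the transposed encoder
        # weights built in _init_architecture and, walking the encoder layers
        # in reverse, applies ReLU wherever the mirrored encoder layer used
        # ReLU; other activations (e.g. sigmoid) are left linear here.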
        encoding = self.forward_pass(input_pl)
        output = encoding
        for i in range(len(self.decoder_weights)):
            output = tf.matmul(
                output, self.decoder_weights[i]) + self.decoder_biases[i]
            if self.activations[len(self.activations) - i - 1] == 'relu':
                output = tf.nn.relu(output)
        return output

    def _init_architecture(self):
        tf.compat.v1.disable_eager_execution()
        self.input_pl = tf.compat.v1.placeholder(tf.float64,
                                                 shape=(None, self.input_dim))
        self.label_pl = tf.compat.v1.placeholder(tf.int32, shape=(None, 1))

        weights = []
        biases = []
        i_dim = self.input_dim
        for layer_size in self.layer_sizes:
            w = weight_variable([i_dim, layer_size])
            b = bias_variable([layer_size])
            i_dim = layer_size
            weights.append(w)
            biases.append(b)

        decoder_weights = []
        decoder_biases = []
        for w in weights[::-1]:
            decoder_weights.append(tf.transpose(w))
            decoder_biases.append(bias_variable([int(w.shape[0])]))

        self.weights = weights
        self.biases = biases
        self.decoder_weights = decoder_weights
        self.decoder_biases = decoder_biases

        self.saver = Saver(self.weights + self.biases + self.decoder_biases)

    def get_fd(self, X, y=None):
        if y is None:
            return {self.input_pl: X}
        else:
            return {self.input_pl: X, self.label_pl: y}

    def eval_var(self, var, X):
        return var.eval(feed_dict=self.get_fd(X), session=self.session)

    def eval_var2(self, var, X, y):
        return var.eval(feed_dict=self.get_fd(X, y), session=self.session)

    def eval_snn(self, X, y, batches):
        snn_loss = []
        for i in range(len(batches)):
            idxs = batches[i]
            x_batch = X[idxs, :]
            y_batch = y[idxs, :]
            batch_snn_loss = self.eval_var2(self.snn_loss, x_batch, y_batch)
            if np.isfinite(batch_snn_loss):
                snn_loss.append(batch_snn_loss)
        snn_loss = np.mean(snn_loss)
        return snn_loss

    def log_losses(self, X, y, batches, val=False, verbose=False, e=0):
        recons_loss = self.eval_var(self.recons_loss, X)
        snn_loss = self.eval_snn(X, y, batches)
        if self.lamda < 1e-9:
            loss = recons_loss
        else:
            loss = recons_loss + self.lamda * snn_loss

        if not val:
            self.history['loss'].append(loss)
            self.history['recons_loss'].append(recons_loss)
            self.history['snn_loss'].append(snn_loss)
        else:
            self.history['val_loss'].append(loss)
            self.history['val_recons_loss'].append(recons_loss)
            self.history['val_snn_loss'].append(snn_loss)

        if verbose:
            if val:
                prefix = "[VAL]"
            else:
                prefix = "[TRAIN]"
            logging.info(
                "{} Epoch {} - Losses: recons: {:.5f} \t snn: {:.5f} \t total: {:.5f}"
                .format(prefix, e, recons_loss, snn_loss, loss))
        return loss

    def get_batches(self, slices, n_batches):
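        # Build label-balanced batches: for every batch, walk the labels in a
        # shuffled order, taking 4 still-unused indices per label until
        # batch_size indices are collected; batches that cannot be filled
        # completely are discarded.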
        batches = []
        tmp = deepcopy(slices)

        for i in range(n_batches):
            current_batch_idxs = []
            keys = list(tmp.keys())
            keys = shuffle(keys)

            # Keep only labels that still have at least 4 unused indices.
            keys = [key for key in keys if len(tmp[key]) >= 4]

            for key in keys:
                if len(tmp[key]) >= 4:
                    current_batch_idxs.extend(tmp[key][:4])
                    del tmp[key][:4]
                if len(current_batch_idxs) == self.batch_size:
                    break
            if len(current_batch_idxs) == self.batch_size:
                batches.append(current_batch_idxs)
        return batches

    def fit(self, X, y, log_time=False, verbose=False):
        t0 = time.time()

        self.history = {'loss': [], 'recons_loss': [], 'snn_loss': []}

        if self.early_stopping:
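            # Hold out a validation split and retry until every label has at
            # least two rows on both sides, so that each batch can contain
            # same-label pairs for the SNN term.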
            repeat = True
            while repeat:
                X, X_val, y, y_val = train_test_split(X, y, shuffle=False)
                self.history['val_loss'] = []
                self.history['val_recons_loss'] = []
                self.history['val_snn_loss'] = []

                labels_to_rows = {}
                for i in range(len(y)):
                    if int(y[i]) in labels_to_rows:
                        labels_to_rows[int(y[i])].append(i)
                    else:
                        labels_to_rows[int(y[i])] = [i]

                labels_to_rows_val = {}
                for i in range(len(y_val)):
                    if int(y_val[i]) in labels_to_rows_val:
                        labels_to_rows_val[int(y_val[i])].append(i)
                    else:
                        labels_to_rows_val[int(y_val[i])] = [i]
                repeat = False

                for key in labels_to_rows:
                    if len(labels_to_rows[key]) < 2:
                        repeat = True
                        break
                for key in labels_to_rows_val:
                    if len(labels_to_rows_val[key]) < 2:
                        repeat = True
                        break
            n_batches_val = int(np.ceil(len(X_val) / self.batch_size)) - 1
        else:
            labels_to_rows = {}
            for i in range(len(y)):
                if int(y[i]) in labels_to_rows:
                    labels_to_rows[int(y[i])].append(i)
                else:
                    labels_to_rows[int(y[i])] = [i]

        n_points = len(X)
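        # Only batches with exactly batch_size rows can feed the SNN graph
        # (see get_log_at_i), hence the conservative batch count below.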
        n_batches = int(np.ceil(n_points / self.batch_size)) - 1

        sess = self.session
        sess.run(tf.compat.v1.global_variables_initializer())

        batches = self.get_batches(labels_to_rows, n_batches)
        if self.early_stopping:
            batches_val = self.get_batches(labels_to_rows_val, n_batches_val)

        loss_value = self.log_losses(X, y, batches, verbose=verbose)

        if not np.isfinite(loss_value):
            logging.warning(
                "Training not started because of non-finite loss value: {}".
                format(loss_value))
            return None

        if self.early_stopping:
            val_loss_value = self.log_losses(X_val,
                                             y_val,
                                             batches_val,
                                             val=True,
                                             verbose=verbose)
            if not np.isfinite(val_loss_value):
                logging.warning(
                    "Training cancelled because of non-finite val loss value: {}"
                    .format(val_loss_value))
                return None

        best_epoch = -1
        min_err = np.inf

        logging.info("Initial loss(es): {}".format(self.history))
        for e in range(self.n_epochs):
            if self.early_stopping and best_epoch > 0 and e > best_epoch + self.patience:
                exited_early_stopping = True
                break

            batches = self.get_batches(labels_to_rows, n_batches)
            if self.early_stopping:
                batches_val = self.get_batches(labels_to_rows_val,
                                               n_batches_val)

            for i in range(len(batches)):
                idxs = batches[i]
                x_batch = X[idxs, :]
                y_batch = y[idxs, :]
                self.train_step.run(feed_dict=self.get_fd(x_batch, y_batch),
                                    session=self.session)
            loss_value = self.log_losses(X, y, batches, verbose=verbose, e=e)
            if not np.isfinite(loss_value):
                logging.warning(
                    "Training stopped after {} epochs because of non-finite loss value ({})"
                    .format(e, loss_value))
                break
            if self.early_stopping:
                val_loss_value = self.log_losses(X_val,
                                                 y_val,
                                                 batches_val,
                                                 val=True,
                                                 verbose=verbose,
                                                 e=e)
                if val_loss_value < min_err:
                    min_err = val_loss_value
                    best_epoch = e
                    self.best_epoch_ = e
                    self.saver.save_weights(self.session)
                    if verbose:
                        logging.info(
                            "===> Epoch: {} \t loss: {:.6f} \t val_loss: {:.6f} ** (new best epoch)"
                            .format(e, loss_value, val_loss_value))

        if self.early_stopping:
            self.saver.restore_weights(self.session)
        else:
            self.saver.save_weights(self.session)

        tend = time.time()
        fitting_time = tend - t0
        self.last_fit_duration_ = fitting_time

        if log_time:
            logging.info(
                "[autoencoder fitting time]: {} minutes and {} seconds".format(
                    fitting_time // 60, int(fitting_time % 60)))
        return self.history

    def transform(self, X):
        output_var = self.forward_pass(self.input_pl)
        output = output_var.eval(feed_dict={self.input_pl: X},
                                 session=self.session)
        return output

    def persist(self, fpath):
        data = self.get_persist_info()
        if os.path.dirname(fpath) != "":
            if not os.path.exists(os.path.dirname(fpath)):
                os.makedirs(os.path.dirname(fpath))
        np.save(fpath, data)

    def serialize(self, fpath):
        self.persist(fpath)

    def get_persist_info(self):
        signature_data = {
            'input_dim': self.input_dim,
            'layer_sizes': self.layer_sizes,
            'activations': self.activations,
            'learning_rate': self.learning_rate,
            'batch_size': self.batch_size,
            'n_epochs': self.n_epochs,
            'early_stopping': self.early_stopping,
            'patience': self.patience,
            'random_state': self.random_state,
            'lamda': self.lamda,
            'T': self.T,
            'epsilon': self.epsilon
        }
        other_data = {
            'best_weights': self.saver.best_params,  # ws and bs
            'history': self.history,
            'best_epoch': self.best_epoch_,
            'last_fit_duration': self.last_fit_duration_,
            'centroids': self.centroids,
            'altered_centroids': self.altered_centroids
        }
        return {'signature': signature_data, 'other': other_data}

    def clone(self):
        data = self.get_persist_info()
        return SNNAE.make_instance(data['signature'], data['other'])

    @staticmethod
    def make_instance(signature_data, other_data):
        instance = SNNAE(**signature_data)
        instance.compile()
        instance.saver.best_params = other_data['best_weights'].copy()
        instance.saver.restore_weights(instance.session)
        instance.history = other_data['history'].copy()
        instance.last_fit_duration_ = other_data['last_fit_duration']
        instance.best_epoch_ = other_data['best_epoch']

        return instance

    @staticmethod
    def load_from_file(fpath):
        data = np.load(fpath, allow_pickle=True)[()]
        return SNNAE.make_instance(data['signature'], data['other'])
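
    # Persistence round-trip sketch (hypothetical path):
    #   ae.persist("snn_ae.npy")
    #   ae2 = SNNAE.load_from_file("snn_ae.npy")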

    @staticmethod
    def build(input_dim=561,
              T=2,
              encoding_dim=5,
              depth=2,
              nh=20,
              activation='sigmoid',
              learning_rate=1e-3,
              batch_size=32,
              n_epochs=500,
              random_state=10,
              early_stopping=False,
              patience=10,
              lamda=1e-1,
              epsilon=1e-8):
        """
        Alternative constructor: builds an SNNAE from an encoding size, a
        depth and a base hidden width (nh) instead of explicit layer sizes.
        """
        encoder_hidden_layers = [int(nh / (2**i)) for i in range(depth - 1)]
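        # Encoder widths halve at each hidden layer; e.g. nh=20, depth=3,
        # encoding_dim=5 gives layer_sizes [20, 10, 5]. Degenerate
        # configurations (a zero width, or a hidden layer narrower than the
        # encoding) are rejected just below.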
        if len(encoder_hidden_layers) > 0:
            if 0 in encoder_hidden_layers or encoder_hidden_layers[
                    -1] < encoding_dim:
                return None
        hidden_layer_sizes = encoder_hidden_layers + [encoding_dim]
        activations = [activation] * depth

        ae = SNNAE(input_dim,
                   hidden_layer_sizes,
                   activations,
                   lamda=lamda,
                   T=T,
                   learning_rate=learning_rate,
                   batch_size=batch_size,
                   n_epochs=n_epochs,
                   early_stopping=early_stopping,
                   patience=patience,
                   random_state=random_state,
                   epsilon=epsilon)
        return ae
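
    # Minimal usage sketch (hypothetical data; X: float array of shape
    # (n, 561), y: int labels of shape (n, 1)):
    #   ae = SNNAE.build(input_dim=561, encoding_dim=5, depth=2, nh=20)
    #   ae.compile()
    #   history = ae.fit(X, y, verbose=True)
    #   Z = ae.transform(X)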

    @staticmethod
    def valid_params(ae_params, encoding_size):
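        # A hyper-parameter set is valid only if early stopping can trigger
        # before training ends (patience < n_epochs) and, for depth >= 2, the
        # narrowest encoder hidden layer, nh / 2**(depth - 2), is still wider
        # than the requested encoding size.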
        if ae_params['patience'] >= ae_params['n_epochs']:
            return False
        nh = ae_params['nh']
        depth = ae_params['depth']
        if depth >= 2:
            return (nh / (2**(depth - 2))) > encoding_size
        return True