def _update_model_prediction_metrics(metrics: _MetricMap, split: str,
                                     label_id: tf.Tensor,
                                     prediction: tf.Tensor):
    """Feeds one batch of predictions into the per-split quality metrics.

    Each metric is looked up in ``metrics`` under the key
    ``'<split>/<metric name>'`` and updated in place via ``update_state``.

    Args:
        metrics: Mapping from metric key to an object exposing
            ``update_state``.
        split: Name of the data split; used as the key prefix.
        label_id: Ground-truth label ids for the batch.
        prediction: Predicted ids for the batch (assumed to be aligned
            element-wise with ``label_id`` — TODO confirm).
    """
    def metric_for(name):
        # Metric keys are namespaced by split, e.g. '<split>/accuracy'.
        return metrics['{}/{}'.format(split, name)]

    # Clustering quality: agreement between labels and predicted ids.
    metric_for('adjusted_mutual_info').update_state(
        utils.adjusted_mutual_info(label_id, prediction))
    metric_for('cluster_purity').update_state(
        utils.cluster_purity(label_id, prediction))

    # Count of distinct predicted ids, computed on the flattened predictions.
    flat_prediction = tf.reshape(prediction, shape=[-1])
    distinct_ids = tf.unique(flat_prediction)[0]
    metric_for('unique_prediction_class_count').update_state(
        tf.size(distinct_ids))

    # Plain accuracy. tf.sign(label_id) is passed as the third argument
    # (presumably the sample weight), so entries with label id 0 get weight
    # 0 — NOTE(review): looks like 0 marks unlabeled entries; confirm.
    metric_for('accuracy').update_state(label_id, prediction,
                                        tf.sign(label_id))

    # Accuracy re-weighted so that each class contributes comparably.
    rebalanced_weights = utils.create_rebalanced_sample_weights(label_id)
    metric_for('class_balanced_accuracy').update_state(
        label_id, prediction, rebalanced_weights)
def evaluate_model(model, generator, len_data_val, x, modelpath, epochs,
                   batch_size, latent_dim, num_clusters, learning_rate, alpha,
                   gamma, theta, epochs_pretrain, decay_factor, ex_name,
                   data_set, validation, dropout, prior_var, prior):
    """Evaluates the trained model with NMI, AMI and cluster purity.

    Restores the model from ``modelpath``, predicts a cluster for every
    evaluated sample of the validation (or test) split, scores the predictions
    against the ground-truth labels and appends the hyperparameters together
    with the scores to a results text file.

    Args:
        model (VarPSOM): Trained VarPSOM model to evaluate.
        generator (generator): Data generator for the batches.
        len_data_val (int): Length of validation set.
        x (tf.Tensor): Input tensor or placeholder.
        modelpath (path): Path from which to restore the model.
        epochs (int): number of epochs of training.
        batch_size (int): Batch size for the training.
        latent_dim (int): Dimensionality of the VarIDEC's latent space.
        num_clusters (int): Number of clusters.
        learning_rate (float): Learning rate for the optimization.
        alpha (float): Student's t-distribution parameter.
        gamma (float): Weight for the KL term of the VarIDEC clustering loss.
        theta (float): Weight for the VAE loss.
        epochs_pretrain (int): Number of VAE pretraining epochs.
        decay_factor (float): Factor for the learning rate decay.
        ex_name (string): Unique name of this particular run.
        data_set (string): Data set for the training.
        validation (bool): If "True" validation set is used for evaluation,
            otherwise test set is used.
        dropout (float): Dropout factor for the feed-forward layers of the VAE.
        prior_var (float): Multiplier of the diagonal variance of the VAE
            multivariate gaussian prior.
        prior (float): Weight of the regularization term of the ELBO.

    Returns:
        dict: Evaluation results with keys "NMI", "Purity" and "AMI"; or
        None when AMI is ~0 and NMI is ~0.125, in which case nothing is
        written to the results file.
    """
    saver = tf.train.Saver()
    # Integer division: only this many full batches are drawn from the
    # generator.
    num_batches = len_data_val // batch_size

    test_k_all = []
    labels_val_all = []
    with tf.Session() as sess:
        saver.restore(sess, modelpath)
        graph = tf.get_default_graph()
        z = graph.get_tensor_by_name("reconstruction_e/decoder/z_e:0")
        is_training = model.is_training

        split = "val" if validation else "test"
        val_gen = generator(split, batch_size)

        print("Evaluation...")
        for _ in range(num_batches):
            # The third item yielded by the generator is not used here.
            batch_data, batch_labels, _ = next(val_gen)
            labels_val_all.extend(batch_labels)
            # NOTE(review): the fetch target was missing in the original
            # (`test_k =, feed_dict=...` does not parse). `model.k` is
            # assumed to be the per-sample cluster assignment — confirm.
            # NOTE(review): is_training=True during evaluation keeps
            # training-mode behaviour (e.g. dropout) active; kept as in the
            # original — confirm this is intended.
            test_k = sess.run(model.k, feed_dict={
                x: batch_data,
                is_training: True,
                z: np.zeros((batch_size, latent_dim)),
            })
            test_k_all.extend(test_k)

    test_nmi = metrics.normalized_mutual_info_score(
        np.array(labels_val_all), test_k_all)
    test_purity = cluster_purity(np.array(test_k_all),
                                 np.array(labels_val_all))
    test_ami = metrics.adjusted_mutual_info_score(test_k_all, labels_val_all)

    results = {"NMI": test_nmi, "Purity": test_purity, "AMI": test_ami}

    # Guard against a specific score combination (AMI ~= 0, NMI ~= 0.125):
    # such a run is neither logged nor returned. Presumably this flags a
    # degenerate/collapsed clustering — confirm with the experiment owner.
    if np.abs(test_ami - 0.) < 0.0001 and np.abs(test_nmi - 0.125) < 0.0001:
        return None

    results_path = ("results_fMNIST_VarIDEC.txt" if data_set == "fMNIST"
                    else "results_MNIST_VarIDEC.txt")
    # `with` closes the file even if a write raises.
    with open(results_path, "a+") as f:
        # gamma is a float weight; it was previously formatted with %d, which
        # truncated it (e.g. 0.5 -> 0).
        f.write(
            "Epochs= %d, num_clusters=%d, latent_dim= %d, batch_size= %d, learning_rate= %f, gamma=%f, "
            "theta=%f, alpha=%f, dropout=%f, decay_factor=%f, prior_var=%f, prior=%f, epochs_pretrain=%d"
            % (epochs, num_clusters, latent_dim, batch_size, learning_rate,
               gamma, theta, alpha, dropout, decay_factor, prior_var, prior,
               epochs_pretrain))
        f.write(", RESULTS NMI: %f, AMI: %f, Purity: %f. Name: %r \n"
                % (results["NMI"], results["AMI"], results["Purity"],
                   ex_name))
    return results
def test_cluster_purity(self):
    # Every predicted cluster contains a single label value: cluster 1 holds
    # labels {1, 1, 1}, cluster 2 holds {0, 0} and cluster 3 holds {0}. The
    # purity is therefore exactly 1.
    labels = tf.constant([[1, 0, 0], [1, 1, 0]])
    predictions = tf.constant([[1, 2, 3], [1, 1, 2]])
    purity = utils.cluster_purity(labels, predictions)
    self.assertEqual(purity, 1.)