예제 #1
0
    def test_hash_metric(self, activate_k, k_set):
        """Evaluate k-activated hash retrieval on the test set.

        The test hash scores in ``self.test_k_hash`` are passed through
        ``activate_k_2D`` with ``k=activate_k``. The resulting activated keys
        serve both as the test hash keys and as the query keys. The raw
        ``self.test_k_hash`` is supplied as the query value. The test-vs-test
        pairwise distance matrix is computed only if ``self.te_te_distance``
        does not exist yet, and is then kept on the instance for reuse.

        Args:
            activate_k: forwarded as ``k`` to ``activate_k_2D`` and as
                ``activate_k`` to ``evaluate_hash_te`` (logged as the
                activation k among ``self.args.d`` buckets).
            k_set: forwarded unchanged to ``evaluate_hash_te``.

        Returns:
            Whatever ``evaluate_hash_te`` returns.
        """
        self.logger.info("Model testing k hash starts")
        self.logger.info("Activation k(={}) in buckets(={})".format(activate_k, self.args.d))

        self.regen_session()
        # The original author annotated this result as [ntest, args.d].
        activated_keys = activate_k_2D(self.test_k_hash, k=activate_k, session=self.sess)

        if not hasattr(self, 'te_te_distance'):
            # Computed lazily once; later calls reuse the cached matrix.
            self.regen_session()
            self.te_te_distance = pairwise_distance_euclid_efficient(
                input1=self.test_embed, input2=self.test_embed,
                session=self.sess, batch_size=128)
            self.logger.info("Calculating pairwise distance from test embeddings")

        performance = evaluate_hash_te(
            test_hash_key=activated_keys,
            te_te_distance=self.te_te_distance,
            te_te_query_key=activated_keys,
            te_te_query_value=self.test_k_hash,
            test_label=self.test_label,
            ncls_test=self.ncls_test,
            activate_k=activate_k,
            k_set=k_set,
            logger=self.logger,
        )

        self.logger.info("Model testing k hash ends")
        return performance
예제 #2
0
    def test_vq(self, activate_k, k_set):
        """Evaluate vector-quantization hashing on the test set.

        Two objects are created lazily and kept on the instance. The first is
        ``self.test_kmc``, constructed as
        ``KMeansClustering(self.test_embed, self.args.m)``. The second is the
        test-vs-test distance matrix ``self.te_te_distance``.

        The ``k_hash`` of the test embeddings is used as the query value. Its
        ``activate_k_2D`` activation is used both as the test hash key and as
        the query key.

        Args:
            activate_k: forwarded as ``k`` to ``activate_k_2D`` and as
                ``activate_k`` to ``evaluate_hash_te`` (logged as the
                activation k among ``self.args.m`` buckets).
            k_set: forwarded unchanged to ``evaluate_hash_te``.

        Returns:
            Whatever ``evaluate_hash_te`` returns.
        """
        self.logger.info("Model testing vq starts")
        self.logger.info("Activation k(={}) in buckets(={})".format(activate_k, self.args.m))

        if not hasattr(self, 'test_kmc'):
            self.regen_session()
            self.test_kmc = KMeansClustering(self.test_embed, self.args.m)

        if not hasattr(self, 'te_te_distance'):
            self.regen_session()
            self.te_te_distance = pairwise_distance_euclid_efficient(
                input1=self.test_embed, input2=self.test_embed,
                session=self.sess, batch_size=128)
            self.logger.info("Calculating pairwise distance from test embeddings")

        # The original author annotated both results as [ntest, args.d].
        query_value = self.test_kmc.k_hash(self.test_embed, self.sess)
        query_key = activate_k_2D(query_value, k=activate_k, session=self.sess)
        self.regen_session()

        performance = evaluate_hash_te(
            test_hash_key=query_key,
            te_te_distance=self.te_te_distance,
            te_te_query_key=query_key,
            te_te_query_value=query_value,
            test_label=self.test_label,
            ncls_test=self.ncls_test,
            activate_k=activate_k,
            k_set=k_set,
            logger=self.logger,
        )

        self.logger.info("Model testing vq ends")
        return performance
예제 #3
0
 def initialize(self):
     """Initialize uninitialized variables and rank validation neighbours.

     Calls ``generate_sess`` and then ``rest_initializer`` on ``self.sess``.
     It then resets ``self.start_epoch`` to 0. Finally it computes the
     validation-vs-validation pairwise Euclidean distances in batches of
     ``self.args.nbatch``. ``self.val_arg_sort`` receives ``np.argsort`` of
     that matrix along axis 1, i.e. each row's indices ordered by ascending
     distance. This assumes the helper returns a 2-D array; TODO confirm.
     """
     self.logger.info("Model initialization starts")
     self.generate_sess()
     rest_initializer(self.sess)
     self.start_epoch = 0

     distances = pairwise_distance_euclid_efficient(
         input1=self.val_embed, input2=self.val_embed,
         session=self.sess, batch_size=self.args.nbatch)
     self.logger.info("Calculating pairwise distance of validation data")

     # Nearest (smallest-distance) columns come first in every row.
     self.val_arg_sort = np.argsort(distances, axis=1)
     self.logger.info("Model initialization ends")
예제 #4
0
 def test_metric(self, k_set):
     """Evaluate distance-based (metric) retrieval on the test set.

     The test-vs-test pairwise Euclidean distance matrix is computed only if
     ``self.te_te_distance`` does not exist yet. It is then kept on the
     instance for reuse. The matrix is passed, together with
     ``self.test_label``, to ``evaluate_metric_te``.

     Args:
         k_set: forwarded unchanged to ``evaluate_metric_te``. Presumably
             these are the k values for the @k metrics; confirm.

     Returns:
         Whatever ``evaluate_metric_te`` returns.
     """
     self.logger.info("Model testing metric starts")
     if not hasattr(self, 'te_te_distance'):
         # Presumably resets the session/graph before building the distance
         # computation (regen_session is not visible here) -- confirm.
         self.regen_session()
         self.te_te_distance = pairwise_distance_euclid_efficient(
             input1=self.test_embed,
             input2=self.test_embed,
             session=self.sess,
             batch_size=128)
         self.logger.info(
             "Calculating pairwise distance from test embeddings")
     performance = evaluate_metric_te(test_label=self.test_label,
                                      te_te_distance=self.te_te_distance,
                                      k_set=k_set,
                                      logger=self.logger)
     # Close the "starts" log line, like the other test_* methods do.
     self.logger.info("Model testing metric ends")
     return performance