Code Example #1
    def __init__(self, train_dataset, val_dataset, test_dataset, logfilepath,
                 args):
        self.args = args

        selectGpuById(self.args.gpu)
        self.logfilepath = logfilepath
        self.logger = LoggerManager(self.logfilepath, __name__)

        self.dataset_dict = dict()
        self.set_train_dataset(train_dataset)
        self.set_val_dataset(val_dataset)
        self.set_test_dataset(test_dataset)
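The constructor itself touches only `args.gpu` and, through the dataset setters, `args.nbatch`; the later examples read many more fields (`ltype`, `m`, `dtype`, and so on). A minimal sketch of the expected namespace, assuming `argparse`; the flag names simply mirror the attribute accesses, and the defaults are purely illustrative:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=0)       # forwarded to selectGpuById
parser.add_argument('--nbatch', type=int, default=128)  # batch size used by the dataset setters
args = parser.parse_args([])  # empty argv: take the defaults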
Code Example #2
    def switch_log_path(self, logfilepath):
        self.logger.remove()
        print("Log file switched from {} to {}".format(self.logfilepath,
                                                       logfilepath))
        self.logfilepath = logfilepath
        self.logger = LoggerManager(self.logfilepath, __name__)
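A hypothetical call with an illustrative path; the method drops the current log handler and reopens a LoggerManager on the new file:

model.switch_log_path('./log/phase2.log')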
Code Example #3
class DeepMetric:
    def __init__(self, train_dataset, val_dataset, test_dataset, logfilepath,
                 args):
        self.args = args

        selectGpuById(self.args.gpu)
        self.logfilepath = logfilepath
        self.logger = LoggerManager(self.logfilepath, __name__)

        self.dataset_dict = dict()
        self.set_train_dataset(train_dataset)
        self.set_val_dataset(val_dataset)
        self.set_test_dataset(test_dataset)

    def set_train_dataset(self, train_dataset):
        self.logger.info("Setting train_dataset starts")
        self.train_dataset = train_dataset
        self.dataset_dict['train'] = self.train_dataset
        self.train_image = self.dataset_dict['train'].image
        self.train_label = self.dataset_dict['train'].label
        self.ntrain, self.height, self.width, self.nchannel = self.train_image.shape
        self.ncls_train = self.train_dataset.nclass
        self.nbatch_train = self.ntrain // self.args.nbatch
        self.logger.info("Setting train_dataset ends")

    def set_test_dataset(self, test_dataset):
        self.logger.info("Setting test_dataset starts")
        self.test_dataset = test_dataset
        self.dataset_dict['test'] = self.test_dataset
        self.test_image = self.dataset_dict['test'].image
        self.test_label = self.dataset_dict['test'].label
        self.ntest = self.test_dataset.ndata
        self.ncls_test = self.test_dataset.nclass
        self.nbatch_test = self.ntest // self.args.nbatch
        self.logger.info("Setting test_dataset ends")

    def set_val_dataset(self, val_dataset):
        self.logger.info("Setting val_dataset starts")
        self.val_dataset = val_dataset
        self.dataset_dict['val'] = self.val_dataset
        self.val_image = self.dataset_dict['val'].image
        self.val_label = self.dataset_dict['val'].label
        self.nval = self.val_dataset.ndata
        self.nbatch_val = self.nval // self.args.nbatch
        self.logger.info("Setting val_dataset ends")

    def switch_log_path(self, logfilepath):
        self.logger.remove()
        print("Log file switched from {} to {}".format(self.logfilepath,
                                                       logfilepath))
        self.logfilepath = logfilepath
        self.logger = LoggerManager(self.logfilepath, __name__)

    def build(self):
        self.logger.info("Model building starts")
        tf.reset_default_graph()
        if self.args.ltype == 'npair':
            self.anc_img = tf.placeholder(tf.float32,
                                          shape=[
                                              self.args.nbatch // 2,
                                              self.height, self.width,
                                              self.nchannel
                                          ])
            self.pos_img = tf.placeholder(tf.float32,
                                          shape=[
                                              self.args.nbatch // 2,
                                              self.height, self.width,
                                              self.nchannel
                                          ])
            self.istrain = tf.placeholder(tf.bool, shape=[])
            self.label = tf.placeholder(tf.int32,
                                        shape=[self.args.nbatch // 2])
        else:  # triplet
            self.img = tf.placeholder(tf.float32,
                                      shape=[
                                          self.args.nbatch, self.height,
                                          self.width, self.nchannel
                                      ])
            self.istrain = tf.placeholder(tf.bool, shape=[])
            self.label = tf.placeholder(tf.int32, shape=[self.args.nbatch])

        self.generate_sess()

        self.conv_net = CONV_DICT[self.args.dataset][self.args.conv]

        if self.args.ltype == 'npair':
            self.anc_last, _ = self.conv_net(self.anc_img,
                                             is_training=self.istrain,
                                             reuse=False)
            self.pos_last, _ = self.conv_net(self.pos_img,
                                             is_training=self.istrain,
                                             reuse=True)
            self.anc_last = tf.nn.relu(self.anc_last)
            self.pos_last = tf.nn.relu(self.pos_last)
        else:  # triplet
            self.last, _ = self.conv_net(self.img,
                                         is_training=self.istrain,
                                         reuse=False)
            self.last = tf.nn.relu(self.last)

        with slim.arg_scope([slim.fully_connected],
                            activation_fn=None,
                            weights_regularizer=slim.l2_regularizer(0.0005),
                            biases_initializer=tf.zeros_initializer()):
            if self.args.ltype == 'npair':
                with tf.variable_scope('Embed', reuse=False):
                    self.anc_embed = slim.fully_connected(self.anc_last,
                                                          self.args.m,
                                                          scope="fc1")
                with tf.variable_scope('Embed', reuse=True):
                    self.pos_embed = slim.fully_connected(self.pos_last,
                                                          self.args.m,
                                                          scope="fc1")
                self.loss = npairs_loss(labels=self.label,
                                        embeddings_anchor=self.anc_embed,
                                        embeddings_positive=self.pos_embed,
                                        reg_lambda=self.args.lamb)
            else:  # triplet
                with tf.variable_scope('Embed', reuse=False):
                    self.embed = slim.fully_connected(self.last,
                                                      self.args.m,
                                                      scope="fc1")
                self.embed_l2_norm = tf.nn.l2_normalize(
                    self.embed, dim=-1)  # embedding with l2 normalization

                def pairwise_distance_c(embeddings):
                    return pairwise_distance_euclid(embeddings, squared=True)

                self.loss = triplet_semihard_loss(
                    labels=self.label,
                    embeddings=self.embed_l2_norm,
                    pairwise_distance=pairwise_distance_c,
                    margin=self.args.ma)
        self.loss += tf.losses.get_regularization_loss()

        initialized_variables = get_initialized_vars(self.sess)
        self.logger.info("Variables loaded from pretrained network\n{}".format(
            vars_info_vl(initialized_variables)))
        self.logger.info("Model building ends")

    def set_up_train(self, pretrain=False):
        self.logger.info("Model setting up train starts")

        decay_func = DECAY_DICT[self.args.dtype]
        if hasattr(self, 'start_epoch'):
            self.logger.info("Current start epoch : {}".format(
                self.start_epoch))
            # Resume the decay schedule from the restored epoch. The params
            # are read back below from dtype/dptype, so set them there (the
            # original indexed hdtype/hdptype, which this method never reads).
            DECAY_PARAMS_DICT[self.args.dtype][self.args.nbatch][
                self.args.dptype]['initial_step'] = \
                self.nbatch_train * self.start_epoch
        self.lr, update_step_op = decay_func(**DECAY_PARAMS_DICT[
            self.args.dtype][self.args.nbatch][self.args.dptype])

        print(vars_info_vl(tf.trainable_variables()))
        update_ops = tf.get_collection("update_ops")

        with tf.control_dependencies(update_ops + [update_step_op]):
            self.train_op = get_multi_train_op(tf.train.AdamOptimizer,
                                               self.loss, [self.lr],
                                               [tf.trainable_variables()])

        self.graph_ops_dict = {
            'train': [self.train_op, self.loss],
            'val': self.loss,
            'test': self.loss
        }
        self.val_embed_tensor1 = tf.placeholder(
            tf.float32, shape=[self.args.nbatch, self.args.m])
        self.val_embed_tensor2 = tf.placeholder(tf.float32,
                                                shape=[self.nval, self.args.m])

        # Squared Euclidean distances via the expansion
        #   ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b
        # shapes: [batch_size, 1] + [1, nval] - 2 * [batch_size, nval]
        self.p_dist = (
            math_ops.reduce_sum(math_ops.square(self.val_embed_tensor1),
                                axis=[1], keep_dims=True)
            + math_ops.reduce_sum(
                math_ops.square(array_ops.transpose(self.val_embed_tensor2)),
                axis=[0], keep_dims=True)
            - 2.0 * math_ops.matmul(self.val_embed_tensor1,
                                    array_ops.transpose(self.val_embed_tensor2)))

        # Clamp small negative values caused by floating-point error, then
        # force non-positive entries to exactly 0.0.
        self.p_dist = math_ops.maximum(self.p_dist, 0.0)  # [batch_size, nval]
        self.p_dist = math_ops.multiply(
            self.p_dist,
            math_ops.to_float(
                math_ops.logical_not(math_ops.less_equal(self.p_dist, 0.0))))
        # Indices of the two smallest distances per row; k=2 because the
        # smallest distance is the query to itself.
        self.p_max_idx = tf.nn.top_k(-self.p_dist, k=2)[1]  # [batch_size, 2]

        self.logger.info("Model setting up train ends")

    def generate_sess(self):
        try:
            self.sess
        except AttributeError:
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            self.sess = tf.Session(config=config)

    def initialize(self):
        '''Initialize uninitialized variables'''
        self.logger.info("Model initialization starts")
        self.generate_sess()
        rest_initializer(self.sess)
        self.start_epoch = 0
        self.logger.info("Model initialization ends")

    def save(self, global_step, save_dir):
        self.logger.info("Model save starts")
        for f in glob.glob(save_dir + '*'):
            os.remove(f)
        saver = tf.train.Saver(max_to_keep=5)
        saver.save(self.sess,
                   os.path.join(save_dir, 'model'),
                   global_step=global_step)
        self.logger.info("Model save in %s" % save_dir)
        self.logger.info("Model save ends")

    def restore(self, save_dir):
        """Restore all variables in graph with the latest version"""
        self.logger.info("Restoring model starts...")
        saver = tf.train.Saver()
        latest_checkpoint = tf.train.latest_checkpoint(save_dir)
        self.logger.info("Restoring from {}".format(latest_checkpoint))
        self.start_epoch = int(
            os.path.basename(latest_checkpoint)[len('model') + 1:])
        self.generate_sess()
        saver.restore(self.sess, latest_checkpoint)
        self.logger.info("Restoring model done.")

    def run_batch(self, key='train'):
        '''
        Args:
            key : string
                one of 'train', 'val', 'test'
        Returns:
            the corresponding graph operations
        '''
        assert key in ['train', 'test',
                       'val'], "key should be train or val or test"
        if self.args.ltype == 'npair':
            batch_anc_img, batch_pos_img, batch_anc_label, batch_pos_label = self.dataset_dict[
                key].next_batch(batch_size=self.args.nbatch)

            feed_dict = {
                self.anc_img: batch_anc_img,
                self.pos_img: batch_pos_img,
                self.label: batch_anc_label,
                self.istrain: key == 'train'
            }
            return self.sess.run(self.graph_ops_dict[key], feed_dict=feed_dict)
        else:  # triplet
            batch_img, batch_label = self.dataset_dict[key].next_batch(
                batch_size=self.args.nbatch)
            feed_dict = {
                self.img: batch_img,
                self.label: batch_label,
                self.istrain: key == 'train'
            }
            return self.sess.run(self.graph_ops_dict[key], feed_dict=feed_dict)

    def train(self, epoch, save_dir, board_dir):
        self.logger.info("Model training starts")

        self.train_writer = SummaryWriter(board_dir + 'train')
        self.val_writer = SummaryWriter(board_dir + 'val')

        max_val_recall = -1
        self.logger.info("Current epoch : {}/{}".format(
            self.start_epoch, epoch))
        self.logger.info("Current lr : {}".format(self.sess.run(self.lr)))

        if self.args.ltype == 'npair':

            def custom_apply_tf_op(inputs, output_gate):
                return apply_tf_op(inputs=inputs,
                                   session=self.sess,
                                   input_gate=self.anc_img,
                                   output_gate=output_gate,
                                   batch_size=self.args.nbatch // 2,
                                   dim=4,
                                   train_gate=self.istrain)
        else:  # triplet

            def custom_apply_tf_op(inputs, output_gate):
                return apply_tf_op(inputs=inputs,
                                   session=self.sess,
                                   input_gate=self.img,
                                   output_gate=output_gate,
                                   batch_size=self.args.nbatch,
                                   dim=4,
                                   train_gate=self.istrain)

        for epoch_ in range(self.start_epoch, self.start_epoch + epoch):
            train_epoch_loss = 0
            for _ in tqdm(range(self.nbatch_train), ascii=True, desc="batch"):
                _, batch_loss = self.run_batch(key='train')
                train_epoch_loss += batch_loss
            # averaging
            train_epoch_loss /= self.nbatch_train

            if self.args.ltype == 'npair':
                self.val_embed = custom_apply_tf_op(inputs=self.val_image,
                                                    output_gate=self.anc_embed)
            else:
                self.val_embed = custom_apply_tf_op(
                    inputs=self.val_image, output_gate=self.embed)  # triplet

            val_p1 = get_recall_at_1_efficient(
                data=self.val_embed, label=self.val_label,
                input1_tensor=self.val_embed_tensor1,
                input2_tensor=self.val_embed_tensor2,
                idx_tensor=self.p_max_idx, session=self.sess)

            self.logger.info(
                "Epoch({}/{}) train loss = {} val p@1 = {}".format(
                    epoch_ + 1, epoch + self.start_epoch, train_epoch_loss,
                    val_p1))
            if train_epoch_loss != train_epoch_loss:  # NaN check: NaN != NaN
                break

            self.train_writer.add_summary("loss", train_epoch_loss, epoch_)
            self.train_writer.add_summary("learning rate",
                                          self.sess.run(self.lr), epoch_)
            self.val_writer.add_summary("p@1", val_p1, epoch_)

            if epoch_ == self.start_epoch or max_val_recall < val_p1:
                max_val_recall = val_p1
                self.save(epoch_ + 1, save_dir)

        self.logger.info("Model training ends")

    def regen_session(self):
        tf.reset_default_graph()
        self.sess.close()
        self.sess = tf.Session()

    def prepare_test(self):
        self.logger.info("Model preparing test")
        self.test_image = self.dataset_dict['test'].image
        self.test_label = self.dataset_dict['test'].label
        self.train_image = self.dataset_dict['train'].image
        self.train_label = self.dataset_dict['train'].label

        if self.args.ltype == 'npair':

            def custom_apply_tf_op(inputs, output_gate):
                return apply_tf_op(inputs=inputs,
                                   session=self.sess,
                                   input_gate=self.anc_img,
                                   output_gate=output_gate,
                                   batch_size=self.args.nbatch // 2,
                                   dim=4,
                                   train_gate=self.istrain)

            self.test_embed = custom_apply_tf_op(inputs=self.test_image,
                                                 output_gate=self.anc_embed)
            self.train_embed = custom_apply_tf_op(inputs=self.train_image,
                                                  output_gate=self.anc_embed)

        else:  # triplet

            def custom_apply_tf_op(inputs, output_gate):
                return apply_tf_op(inputs=inputs,
                                   session=self.sess,
                                   input_gate=self.img,
                                   output_gate=output_gate,
                                   batch_size=self.args.nbatch,
                                   dim=4,
                                   train_gate=self.istrain)

            self.test_embed = custom_apply_tf_op(inputs=self.test_image,
                                                 output_gate=self.embed)
            self.train_embed = custom_apply_tf_op(inputs=self.train_image,
                                                  output_gate=self.embed)

    def test_metric(self, k_set):
        self.logger.info("Model testing metric starts")
        if not hasattr(self, 'te_tr_distance') and not hasattr(
                self, 'te_te_distance'):
            self.regen_session()
            self.te_tr_distance = self.sess.run(
                pairwise_distance_euclid_v2(
                    tf.convert_to_tensor(self.test_embed, dtype=tf.float32),
                    tf.convert_to_tensor(self.train_embed, dtype=tf.float32)))
            self.te_te_distance = self.sess.run(
                pairwise_distance_euclid(
                    tf.convert_to_tensor(self.test_embed, dtype=tf.float32)))
        performance = evaluate_metric_te_tr(test_label=self.test_label,
                                            train_label=self.train_label,
                                            te_te_distance=self.te_te_distance,
                                            te_tr_distance=self.te_tr_distance,
                                            k_set=k_set,
                                            logger=self.logger)
        self.regen_session()
        return performance

    def test_th(self, activate_k, k_set):
        self.logger.info("Model testing thresholding starts")
        self.logger.info("Activation k(={}) in embeddings(={})".format(
            activate_k, self.args.m))
        test_k_activate = activate_k_2D(self.test_embed,
                                        k=activate_k,
                                        session=self.sess)  # [ntest, args.m]
        train_k_activate = activate_k_2D(self.train_embed,
                                         k=activate_k,
                                         session=self.sess)  # [ntrain, args.m]
        self.regen_session()

        if not hasattr(self, 'te_tr_distance') and not hasattr(
                self, 'te_te_distance'):
            self.regen_session()
            self.te_tr_distance = self.sess.run(
                pairwise_distance_euclid_v2(
                    tf.convert_to_tensor(self.test_embed, dtype=tf.float32),
                    tf.convert_to_tensor(self.train_embed, dtype=tf.float32)))
            self.te_te_distance = self.sess.run(
                pairwise_distance_euclid(
                    tf.convert_to_tensor(self.test_embed, dtype=tf.float32)))

        performance = evaluate_hash_te_tr(
            train_hash_key=train_k_activate, test_hash_key=test_k_activate,
            te_tr_distance=self.te_tr_distance,
            te_te_distance=self.te_te_distance,
            te_tr_query_key=test_k_activate, te_tr_query_value=self.test_embed,
            te_te_query_key=test_k_activate, te_te_query_value=self.test_embed,
            train_label=self.train_label, test_label=self.test_label,
            ncls_train=self.ncls_train, ncls_test=self.ncls_test,
            activate_k=activate_k, k_set=k_set, logger=self.logger)

        self.logger.info("Model testing thresholding ends")
        return performance

    def test_vq(self, activate_k, k_set):
        self.logger.info("Model testing vq starts")
        self.logger.info("Activation k(={}) in buckets(={})".format(
            activate_k, self.args.m))
        if not hasattr(self, 'train_kmc') and not hasattr(self, 'test_kmc'):
            self.test_kmc = KMeansClustering(self.test_embed, self.args.m)
            self.train_kmc = KMeansClustering(self.train_embed, self.args.m)
            self.regen_session()
        if not hasattr(self, 'te_tr_distance') and not hasattr(
                self, 'te_te_distance'):
            self.te_tr_distance = self.sess.run(
                pairwise_distance_euclid_v2(
                    tf.convert_to_tensor(self.test_embed, dtype=tf.float32),
                    tf.convert_to_tensor(self.train_embed, dtype=tf.float32)))
            self.te_te_distance = self.sess.run(
                pairwise_distance_euclid(
                    tf.convert_to_tensor(self.test_embed, dtype=tf.float32)))
            self.regen_session()

        te_te_query_value = self.test_kmc.k_hash(
            self.test_embed, self.sess)  # [ntest, args.m] center test
        te_tr_query_value = self.train_kmc.k_hash(
            self.test_embed, self.sess)  # [ntest, args.m] center train
        self.regen_session()

        te_te_query_key = activate_k_2D(te_te_query_value,
                                        k=activate_k,
                                        session=self.sess)  # [ntest, args.m]
        test_hash_key = te_te_query_key

        te_tr_query_key = activate_k_2D(te_tr_query_value,
                                        k=activate_k,
                                        session=self.sess)  # [ntest, args.m]
        train_hash_key = activate_k_2D(self.train_kmc.k_hash(
            self.train_embed, self.sess),
                                       k=activate_k,
                                       session=self.sess)  # [ntrain, args.m]
        self.regen_session()

        performance = evaluate_hash_te_tr(
            train_hash_key=train_hash_key, test_hash_key=test_hash_key,
            te_tr_distance=self.te_tr_distance,
            te_te_distance=self.te_te_distance,
            te_tr_query_key=te_tr_query_key,
            te_tr_query_value=te_tr_query_value,
            te_te_query_key=te_te_query_key,
            te_te_query_value=te_te_query_value,
            train_label=self.train_label, test_label=self.test_label,
            ncls_train=self.ncls_train, ncls_test=self.ncls_test,
            activate_k=activate_k, k_set=k_set, logger=self.logger)

        self.logger.info("Model testing vq ends")
        return performance

    def delete(self):
        tf.reset_default_graph()
        self.logger.remove()
        del self.logger
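A typical lifecycle for this class, as a hypothetical driver: `train_ds`, `val_ds`, and `test_ds` stand in for the dataset objects the constructor expects, the `args` namespace carries the fields referenced above (`gpu`, `nbatch`, `ltype`, `dtype`, `dptype`, `conv`, `dataset`, `m`, `lamb`, `ma`), and all paths are illustrative:

model = DeepMetric(train_ds, val_ds, test_ds, './log/metric.log', args)
model.build()            # placeholders, conv net, embedding head, metric loss
model.set_up_train()     # lr decay schedule, train op, validation p@1 graph
model.initialize()       # initialize remaining variables, start_epoch = 0
model.train(epoch=50, save_dir='./ckpt/', board_dir='./board/')
model.prepare_test()     # embed the train and test images
model.test_metric(k_set=[1, 2, 4, 8])
model.delete()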
Code Example #4
class DeepMetric:
    def __init__(self, train_dataset, val_dataset, test_dataset, logfilepath,
                 args):
        self.args = args

        selectGpuById(self.args.gpu)
        self.logfilepath = logfilepath
        self.logger = LoggerManager(self.logfilepath, __name__)

        self.dataset_dict = dict()
        self.set_train_dataset(train_dataset)
        self.set_val_dataset(val_dataset)
        self.set_test_dataset(test_dataset)

    def set_train_dataset(self, train_dataset):
        self.logger.info("Setting train_dataset starts")
        self.train_dataset = train_dataset
        self.dataset_dict['train'] = self.train_dataset
        self.train_image = self.dataset_dict['train'].image
        self.train_label = self.dataset_dict['train'].label
        self.ntrain, self.height, self.width, self.nchannel = self.train_image.shape
        self.ncls_train = self.train_dataset.nclass
        self.nbatch_train = self.ntrain // self.args.nbatch
        self.logger.info("Setting train_dataset ends")

    def set_test_dataset(self, test_dataset):
        self.logger.info("Setting test_dataset starts")
        self.test_dataset = test_dataset
        self.dataset_dict['test'] = self.test_dataset
        self.test_image = self.dataset_dict['test'].image
        self.test_label = self.dataset_dict['test'].label
        self.ncls_test = self.test_dataset.nclass
        self.ntest = self.test_dataset.ndata
        self.nbatch_test = self.ntest // self.args.nbatch
        self.logger.info("Setting test_dataset ends")

    def set_val_dataset(self, val_dataset):
        self.logger.info("Setting val_dataset starts")
        self.val_dataset = val_dataset
        self.dataset_dict['val'] = self.val_dataset
        self.val_image = self.dataset_dict['val'].image
        self.val_label = self.dataset_dict['val'].label
        self.ncls_val = self.val_dataset.nclass
        self.nval = self.val_dataset.ndata
        self.nbatch_val = self.nval // self.args.nbatch
        self.logger.info("Setting val_dataset ends")

    def switch_log_path(self, logfilepath):
        self.logger.remove()
        print("Log file switched from {} to {}".format(self.logfilepath,
                                                       logfilepath))
        self.logfilepath = logfilepath
        self.logger = LoggerManager(self.logfilepath, __name__)

    def build(self, pretrain=False):
        self.logger.info("Model building starts")
        tf.reset_default_graph()
        if self.args.hltype == 'npair':
            self.anc_img = tf.placeholder(tf.float32,
                                          shape=[
                                              self.args.nbatch // 2,
                                              self.height, self.width,
                                              self.nchannel
                                          ])
            self.pos_img = tf.placeholder(tf.float32,
                                          shape=[
                                              self.args.nbatch // 2,
                                              self.height, self.width,
                                              self.nchannel
                                          ])
            self.istrain = tf.placeholder(tf.bool, shape=[])
            self.label = tf.placeholder(tf.int32,
                                        shape=[self.args.nbatch // 2])
        else:  # triplet
            self.img = tf.placeholder(tf.float32,
                                      shape=[
                                          self.args.nbatch, self.height,
                                          self.width, self.nchannel
                                      ])
            self.istrain = tf.placeholder(tf.bool, shape=[])
            self.label = tf.placeholder(tf.int32, shape=[self.args.nbatch])

        self.generate_sess()

        self.conv_net = CONV_DICT[self.args.dataset][self.args.conv]

        if self.args.hltype == 'npair':
            self.anc_last, _ = self.conv_net(self.anc_img,
                                             is_training=self.istrain,
                                             reuse=False)
            self.pos_last, _ = self.conv_net(self.pos_img,
                                             is_training=self.istrain,
                                             reuse=True)
            self.anc_last = tf.nn.relu(self.anc_last)
            self.pos_last = tf.nn.relu(self.pos_last)
        else:
            self.last, _ = self.conv_net(self.img,
                                         is_training=self.istrain,
                                         reuse=False)
            self.last = tf.nn.relu(self.last)

        with slim.arg_scope(
            [slim.fully_connected],
                activation_fn=None,
                weights_regularizer=slim.l2_regularizer(0.0005),
                biases_initializer=tf.zeros_initializer(),
                weights_initializer=tf.truncated_normal_initializer(0.0,
                                                                    0.01)):
            if self.args.hltype == 'npair':
                with tf.variable_scope('Embed', reuse=False):
                    self.anc_embed = slim.fully_connected(self.anc_last,
                                                          self.args.m,
                                                          scope="fc1")
                with tf.variable_scope('Embed', reuse=True):
                    self.pos_embed = slim.fully_connected(self.pos_last,
                                                          self.args.m,
                                                          scope="fc1")
            else:  # triplet
                with tf.variable_scope('Embed', reuse=False):
                    self.embed = slim.fully_connected(self.last,
                                                      self.args.m,
                                                      scope="fc1")

        initialized_variables = get_initialized_vars(self.sess)
        self.logger.info("Variables loaded from pretrained network\n{}".format(
            vars_info_vl(initialized_variables)))
        self.logger.info("Model building ends")

    def build_hash(self):
        self.logger.info("Model building train hash starts")

        self.mcf = SolveMaxMatching(nworkers=self.args.nsclass,
                                    ntasks=self.args.d,
                                    k=self.args.k,
                                    pairwise_lamb=self.args.plamb)

        if self.args.hltype == 'triplet':
            self.objective = tf.placeholder(
                tf.float32, shape=[self.args.nbatch, self.args.d])
        else:
            self.objective = tf.placeholder(
                tf.float32, shape=[self.args.nbatch // 2, self.args.d])

        with slim.arg_scope(
            [slim.fully_connected],
                activation_fn=None,
                weights_regularizer=slim.l2_regularizer(0.0005),
                biases_initializer=tf.zeros_initializer(),
                weights_initializer=tf.truncated_normal_initializer(0.0,
                                                                    0.01)):
            if self.args.hltype == 'triplet':
                self.embed_k_hash = self.last
                with tf.variable_scope('Hash', reuse=False):
                    self.embed_k_hash = slim.fully_connected(self.embed_k_hash,
                                                             self.args.d,
                                                             scope="fc1")
                self.embed_k_hash_l2_norm = tf.nn.l2_normalize(
                    self.embed_k_hash,
                    dim=-1)  # embedding with l2 normalization
                self.pairwise_distance = PAIRWISE_DISTANCE_WITH_OBJECTIVE_DICT[
                    self.args.hdt]
                self.loss_hash = triplet_semihard_loss_hash(
                    labels=self.label,
                    embeddings=self.embed_k_hash_l2_norm,
                    objectives=self.objective,
                    pairwise_distance=self.pairwise_distance,
                    margin=self.args.hma)
            else:
                self.anc_embed_k_hash = self.anc_last
                self.pos_embed_k_hash = self.pos_last
                with tf.variable_scope('Hash', reuse=False):
                    self.anc_embed_k_hash = slim.fully_connected(
                        self.anc_embed_k_hash, self.args.d, scope="fc1")
                with tf.variable_scope('Hash', reuse=True):
                    self.pos_embed_k_hash = slim.fully_connected(
                        self.pos_embed_k_hash, self.args.d, scope="fc1")
                self.similarity_func = PAIRWISE_SIMILARITY_WITH_OBJECTIVE_DICT[
                    self.args.hdt]
                self.loss_hash = npairs_loss_hash(
                    labels=self.label,
                    embeddings_anchor=self.anc_embed_k_hash,
                    embeddings_positive=self.pos_embed_k_hash,
                    objective=self.objective,
                    similarity_func=self.similarity_func,
                    reg_lambda=self.args.hlamb)
        self.logger.info("Model building train hash ends")

    def set_up_train_hash(self):
        self.logger.info("Model setting up train hash starts")

        decay_func = DECAY_DICT[self.args.hdtype]
        if hasattr(self, 'start_epoch'):
            self.logger.info("Current start epoch : {}".format(
                self.start_epoch))
            DECAY_PARAMS_DICT[self.args.hdtype][self.args.nbatch][
                self.args.hdptype]['initial_step'] = \
                self.nbatch_train * self.start_epoch
        self.lr_hash, update_step_op = decay_func(**DECAY_PARAMS_DICT[
            self.args.hdtype][self.args.nbatch][self.args.hdptype])

        update_ops = tf.get_collection("update_ops")

        var_slow_list, var_fast_list = list(), list()
        for var in tf.trainable_variables():
            if 'Hash' in var.name: var_fast_list.append(var)
            else: var_slow_list.append(var)

        with tf.control_dependencies(update_ops + [update_step_op]):
            self.train_op_hash = get_multi_train_op(
                tf.train.AdamOptimizer, self.loss_hash,
                [0.1 * self.lr_hash, self.lr_hash],
                [var_slow_list, var_fast_list])

        self.EMBED_HASH = self.anc_embed_k_hash if self.args.hltype == 'npair' else self.embed_k_hash
        self.max_k_idx = tf.nn.top_k(self.EMBED_HASH,
                                     k=self.args.k)[1]  # [batch_size, k]

        self.graph_ops_hash_dict = {
            'train': [self.train_op_hash, self.loss_hash],
            'val': self.loss_hash
        }
        self.logger.info("Model setting up train hash ends")

    def generate_sess(self):
        try:
            self.sess
        except AttributeError:
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            self.sess = tf.Session(config=config)

    def initialize(self):
        '''Initialize uninitialized variables'''
        self.logger.info("Model initialization starts")
        self.generate_sess()
        rest_initializer(self.sess)
        self.start_epoch = 0
        # Requires prepare_test() to have populated self.val_embed.
        self.logger.info("Calculating pairwise distance of validation data")
        val_p_dist = pairwise_distance_euclid_efficient(
            input1=self.val_embed,
            input2=self.val_embed,
            session=self.sess,
            batch_size=self.args.nbatch)
        self.val_arg_sort = np.argsort(val_p_dist, axis=1)
        self.logger.info("Model initialization ends")

    def save(self, global_step, save_dir):
        self.logger.info("Model save starts")
        for f in glob.glob(save_dir + '*'):
            os.remove(f)
        saver = tf.train.Saver(max_to_keep=5)
        saver.save(self.sess,
                   os.path.join(save_dir, 'model'),
                   global_step=global_step)
        self.logger.info("Model save in %s" % save_dir)
        self.logger.info("Model save ends")

    def save_hash(self, global_step, save_dir):
        self.logger.info("Model save starts")
        for f in glob.glob(save_dir + '*'):
            os.remove(f)
        saver = tf.train.Saver(max_to_keep=5)
        saver.save(self.sess,
                   os.path.join(save_dir, 'model'),
                   global_step=global_step)
        self.logger.info("Model save in %s" % save_dir)
        self.logger.info("Model save ends")

    def restore(self, save_dir):
        """Restore all variables in graph with the latest version"""
        self.logger.info("Restoring model starts...")
        saver = tf.train.Saver()
        latest_checkpoint = tf.train.latest_checkpoint(save_dir)
        self.logger.info("Restoring from {}".format(latest_checkpoint))
        self.generate_sess()
        saver.restore(self.sess, latest_checkpoint)
        self.logger.info("Restoring model done.")

    def restore_hash(self, save_dir):
        """Restore all variables in graph with the latest version"""
        self.logger.info("Restoring model starts...")
        saver = tf.train.Saver()
        latest_checkpoint = tf.train.latest_checkpoint(save_dir)
        self.logger.info("Restoring from {}".format(latest_checkpoint))
        self.start_epoch = int(
            os.path.basename(latest_checkpoint)[len('model') + 1:])
        self.generate_sess()
        saver.restore(self.sess, latest_checkpoint)
        self.logger.info("Restoring model done.")

    def run_batch_hash(self, key='train'):
        '''
        Args:
            key : string
                one of 'train', 'val', 'test'
        Returns:
            the corresponding graph operations
        '''
        assert key in ['train', 'test',
                       'val'], "key should be train or val or test"
        if self.args.hltype == 'npair':
            batch_anc_img, batch_pos_img, batch_anc_label, batch_pos_label = self.dataset_dict[
                key].next_batch(batch_size=self.args.nbatch)
            feed_dict = {
                self.anc_img: batch_anc_img,
                self.pos_img: batch_pos_img,
                self.label: batch_anc_label,
                self.istrain: key == 'train'
            }

            # [self.args.nbatch//2, self.args.d]
            anc_unary, pos_unary = self.sess.run(
                [self.anc_embed_k_hash, self.pos_embed_k_hash],
                feed_dict=feed_dict)

            # Average the anchor/positive codes, then average over the samples
            # of each class to get one unary score row per class.
            unary = 0.5 * (anc_unary + pos_unary)  # [batch_size//2, d]
            unary = np.mean(np.reshape(unary,
                                       [self.args.nsclass, -1, self.args.d]),
                            axis=1)  # [nsclass, d]

            # Solve the class-to-bucket matching and encode the result as a
            # binary objective matrix.
            results = self.mcf.solve(unary)
            objective = np.zeros([self.args.nsclass, self.args.d],
                                 dtype=np.float32)  # [nsclass, d]
            for i, j in results:
                objective[i][j] = 1
            # Repeat each class's objective row for every sample of that class.
            objective = np.reshape(
                np.transpose(
                    np.tile(np.transpose(objective, [1, 0]),
                            [self.args.nbatch // (2 * self.args.nsclass), 1]),
                    [1, 0]),
                [self.args.nbatch // 2, self.args.d])  # [batch_size//2, d]
            feed_dict[self.objective] = objective
            return self.sess.run(self.graph_ops_hash_dict[key],
                                 feed_dict=feed_dict)
        else:
            batch_img, batch_label = self.dataset_dict[key].next_batch(
                batch_size=self.args.nbatch)
            feed_dict = {
                self.img: batch_img,
                self.label: batch_label,
                self.istrain: key == 'train'
            }

            unary = self.sess.run(self.embed_k_hash_l2_norm,
                                  feed_dict=feed_dict)  # [batch_size, d]
            # Average over the samples of each class.
            unary = np.mean(np.reshape(unary,
                                       [self.args.nsclass, -1, self.args.d]),
                            axis=1)  # [nsclass, d]

            # Solve the class-to-bucket matching and encode the result as a
            # binary objective matrix.
            results = self.mcf.solve(unary)
            objective = np.zeros([self.args.nsclass, self.args.d],
                                 dtype=np.float32)
            for i, j in results:
                objective[i][j] = 1
            # Repeat each class's objective row for every sample of that class.
            objective = np.reshape(
                np.transpose(
                    np.tile(np.transpose(objective, [1, 0]),
                            [self.args.nbatch // self.args.nsclass, 1]),
                    [1, 0]), [self.args.nbatch, -1])  # [batch_size, d]
            feed_dict[self.objective] = objective
            return self.sess.run(self.graph_ops_hash_dict[key],
                                 feed_dict=feed_dict)

    def train_hash(self, epoch, save_dir, board_dir):
        self.logger.info("Model training starts")

        self.train_writer_hash = SummaryWriter(board_dir + 'train')
        self.val_writer_hash = SummaryWriter(board_dir + 'val')

        self.logger.info("Current epoch : {}/{}".format(
            self.start_epoch, epoch))
        self.logger.info("Current lr : {}".format(self.sess.run(self.lr_hash)))

        if self.args.hltype == 'npair':

            def custom_apply_tf_op(inputs, output_gate):
                return apply_tf_op(inputs=inputs,
                                   session=self.sess,
                                   input_gate=self.anc_img,
                                   output_gate=output_gate,
                                   batch_size=self.args.nbatch // 2,
                                   dim=4,
                                   train_gate=self.istrain)
        else:  # triplet

            def custom_apply_tf_op(inputs, output_gate):
                return apply_tf_op(inputs=inputs,
                                   session=self.sess,
                                   input_gate=self.img,
                                   output_gate=output_gate,
                                   batch_size=self.args.nbatch,
                                   dim=4,
                                   train_gate=self.istrain)

        val_max_k_idx = custom_apply_tf_op(inputs=self.val_image,
                                           output_gate=self.max_k_idx)
        val_nmi, val_suf = get_nmi_suf_quick(index_array=val_max_k_idx,
                                             label_array=self.val_label,
                                             ncluster=self.args.d,
                                             nlabel=self.ncls_val)
        # Precision@1 on validation: for each point, walk its neighbours in
        # distance order, stop at the first one (excluding itself) sharing at
        # least one activated bucket, and count a success if the labels match.
        nsuccess = 0
        for i in range(self.nval):
            for j in self.val_arg_sort[i]:
                if i == j: continue
                if len(set(val_max_k_idx[j]) & set(val_max_k_idx[i])) > 0:
                    if self.val_label[i] == self.val_label[j]: nsuccess += 1
                    break
        val_p1 = nsuccess / self.nval
        max_val_p1 = val_p1
        self.val_writer_hash.add_summary("suf", val_suf, self.start_epoch)
        self.val_writer_hash.add_summary("nmi", val_nmi, self.start_epoch)
        self.val_writer_hash.add_summary("p1", val_p1, self.start_epoch)

        for epoch_ in range(self.start_epoch, epoch):
            train_epoch_loss = 0
            for _ in tqdm(range(self.nbatch_train), ascii=True, desc="batch"):
                _, batch_loss = self.run_batch_hash(key='train')
                train_epoch_loss += batch_loss

            val_max_k_idx = custom_apply_tf_op(inputs=self.val_image,
                                               output_gate=self.max_k_idx)
            val_nmi, val_suf = get_nmi_suf_quick(index_array=val_max_k_idx,
                                                 label_array=self.val_label,
                                                 ncluster=self.args.d,
                                                 nlabel=self.ncls_val)
            nsuccess = 0
            for i in range(self.nval):
                for j in self.val_arg_sort[i]:
                    if i == j: continue
                    if len(set(val_max_k_idx[j]) & set(val_max_k_idx[i])) > 0:
                        if self.val_label[i] == self.val_label[j]:
                            nsuccess += 1
                        break
            val_p1 = nsuccess / self.nval
            # averaging
            train_epoch_loss /= self.nbatch_train

            self.logger.info("Epoch({}/{}) train loss = {} val suf = {} val nmi = {} val p1 = {}"\
                    .format(epoch_ + 1, epoch, train_epoch_loss, val_suf, val_nmi, val_p1))

            self.train_writer_hash.add_summary("loss", train_epoch_loss,
                                               epoch_ + 1)
            self.train_writer_hash.add_summary("learning rate",
                                               self.sess.run(self.lr_hash),
                                               epoch_ + 1)
            self.val_writer_hash.add_summary("suf", val_suf, epoch_ + 1)
            self.val_writer_hash.add_summary("nmi", val_nmi, epoch_ + 1)
            self.val_writer_hash.add_summary("p1", val_p1, epoch_ + 1)

            if epoch_ == self.start_epoch or max_val_p1 < val_p1:
                max_val_p1 = val_p1
                self.save_hash(epoch_ + 1, save_dir)

        self.logger.info("Model training ends")

    def regen_session(self):
        tf.reset_default_graph()
        self.sess.close()
        self.sess = tf.Session()

    def prepare_test(self):
        self.logger.info("Model preparing test")
        if self.args.hltype == 'npair':

            def custom_apply_tf_op(inputs, output_gate):
                return apply_tf_op(inputs=inputs,
                                   session=self.sess,
                                   input_gate=self.anc_img,
                                   output_gate=output_gate,
                                   batch_size=self.args.nbatch // 2,
                                   dim=4,
                                   train_gate=self.istrain)

            self.test_embed = custom_apply_tf_op(inputs=self.test_image,
                                                 output_gate=self.anc_embed)
            self.val_embed = custom_apply_tf_op(inputs=self.val_image,
                                                output_gate=self.anc_embed)
        else:  # triplet

            def custom_apply_tf_op(inputs, output_gate):
                return apply_tf_op(inputs=inputs,
                                   session=self.sess,
                                   input_gate=self.img,
                                   output_gate=output_gate,
                                   batch_size=self.args.nbatch,
                                   dim=4,
                                   train_gate=self.istrain)

            self.test_embed = custom_apply_tf_op(inputs=self.test_image,
                                                 output_gate=self.embed)
            self.val_embed = custom_apply_tf_op(inputs=self.val_image,
                                                output_gate=self.embed)

    def prepare_test_hash(self):
        self.logger.info("Model preparing test")
        if self.args.hltype == 'npair':

            def custom_apply_tf_op(inputs, output_gate):
                return apply_tf_op(inputs=inputs,
                                   session=self.sess,
                                   input_gate=self.anc_img,
                                   output_gate=output_gate,
                                   batch_size=self.args.nbatch // 2,
                                   dim=4,
                                   train_gate=self.istrain)

            self.test_k_hash = custom_apply_tf_op(
                inputs=self.test_image, output_gate=self.anc_embed_k_hash)
        else:  # triplet

            def custom_apply_tf_op(inputs, output_gate):
                return apply_tf_op(inputs=inputs,
                                   session=self.sess,
                                   input_gate=self.img,
                                   output_gate=output_gate,
                                   batch_size=self.args.nbatch,
                                   dim=4,
                                   train_gate=self.istrain)

            self.test_k_hash = custom_apply_tf_op(
                inputs=self.test_image, output_gate=self.embed_k_hash_l2_norm)

    def test_hash_metric(self, activate_k, k_set):
        self.logger.info("Model testing k hash starts")
        self.logger.info("Activation k(={}) in buckets(={})".format(
            activate_k, self.args.d))

        self.regen_session()
        test_k_activate = activate_k_2D(self.test_k_hash,
                                        k=activate_k,
                                        session=self.sess)  # [ntest, args.d]
        if not hasattr(self, 'te_te_distance'):
            self.regen_session()
            self.logger.info(
                "Calculating pairwise distance from test embeddings")
            self.te_te_distance = pairwise_distance_euclid_efficient(
                input1=self.test_embed,
                input2=self.test_embed,
                session=self.sess,
                batch_size=128)

        performance = evaluate_hash_te(
            test_hash_key=test_k_activate,
            te_te_distance=self.te_te_distance,
            te_te_query_key=test_k_activate,
            te_te_query_value=self.test_k_hash,
            test_label=self.test_label, ncls_test=self.ncls_test,
            activate_k=activate_k, k_set=k_set, logger=self.logger)

        self.logger.info("Model testing k hash ends")
        return performance

    def delete(self):
        tf.reset_default_graph()
        self.logger.remove()
        del self.logger
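This variant adds a hashing head on top of the base embedding. A hypothetical end-to-end driver, assuming the same dataset objects, an `args` namespace that also carries the hash hyperparameters (`hltype`, `hdtype`, `hdptype`, `hdt`, `hma`/`hlamb`, `d`, `k`, `nsclass`, `plamb`), and illustrative paths; note that `initialize` reads `self.val_embed`, so `prepare_test` must run first:

model = DeepMetric(train_ds, val_ds, test_ds, './log/hash.log', args)
model.build()              # base conv net and embedding heads
model.build_hash()         # hash head, matching solver, hash loss
model.set_up_train_hash()  # two-speed Adam: 0.1x lr for base vars, full lr for 'Hash' vars
model.prepare_test()       # populates val/test embeddings
model.initialize()         # also precomputes val_arg_sort from val_embed
model.train_hash(epoch=50, save_dir='./ckpt_hash/', board_dir='./board_hash/')
model.prepare_test_hash()  # populates test_k_hash
model.test_hash_metric(activate_k=args.k, k_set=[1, 2, 4, 8])
model.delete()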