def test1():
    '''
    Example results (one random run):
    unary potential :
    [[ 0.89292312  0.64831936  0.56726688  0.13208358  0.05779465  0.48978515]
     [ 0.83382476  0.08014071  0.61772549  0.95149459  0.04179085  0.92253984]
     [ 0.76766159  0.6634347   0.91049119  0.6748744   0.17438728  0.51890275]
     [ 0.90997762  0.18447894  0.81440657  0.09081913  0.46642204  0.47917976]
     [ 0.72631254  0.94356716  0.05386514  0.57434492  0.69070927  0.39979905]]
    objective :
    [[ 1.  1.  0.  0.  0.  0.]
     [ 0.  0.  0.  1.  0.  1.]
     [ 0.  0.  1.  1.  0.  0.]
     [ 1.  0.  1.  0.  0.  0.]
     [ 0.  1.  0.  0.  1.  0.]]
    '''
    pairwise_lamb = 0.1
    nworkers = 5
    ntasks = 6
    k = 2
    mcf = SolveMaxMatching(nworkers=nworkers,
                           ntasks=ntasks,
                           k=k,
                           pairwise_lamb=pairwise_lamb)
    unary_potential = np.random.random([nworkers, ntasks])
    mcf_results = mcf.solve(unary_potential)
    objective = results2objective(results=mcf_results,
                                  nworkers=nworkers,
                                  ntasks=ntasks)

    print("unary potential :\n{}".format(unary_potential))
    print("objective :\n{}".format(objective))
def test6():
    '''Speed check'''
    # hyper parameter setting
    plamb = 0.0
    nworkers = 64
    ntasks = 8
    max_label = 8
    max_iter = 20

    mcf = SolveMaxMatching(nworkers=nworkers,
                           ntasks=ntasks * max_label,
                           k=1,
                           pairwise_lamb=plamb)

    w_label_time = 0
    wo_label_time = 0

    for _ in range(max_iter):
        unary_potential = np.random.random([nworkers, ntasks])
        plabel = np.random.randint(max_label, size=nworkers)

        start_time = time.time()
        solvemaxmatching_label(unary_potential, plabel, max_label, plamb=plamb)
        end_time = time.time()
        w_label_time += end_time - start_time

        unary_potential = np.random.random([nworkers, ntasks * max_label])
        start_time = time.time()
        results2objective(results=mcf.solve(unary_potential),
                          nworkers=nworkers,
                          ntasks=ntasks * max_label)
        end_time = time.time()
        wo_label_time += end_time - start_time
    print("w label time : {}sec\nwo label time : {} sec".format(
        w_label_time, wo_label_time))
def test10():
    '''
    Experiment for rebuttal.
    '''
    pairwise_lamb = 1.0
    nworkers = 10
    ntasks = 2
    k = 1

    mcf = SolveMaxMatching(nworkers=nworkers,
                           ntasks=ntasks,
                           k=k,
                           pairwise_lamb=pairwise_lamb)

    hnworkers = nworkers // 2
    unary_potential = np.array(hnworkers * [0.1, 0] +
                               (nworkers - hnworkers) * [0, 0.2])
    unary_potential = np.reshape(unary_potential, [nworkers, ntasks])
    mcf_objective = results2objective(results=mcf.solve(unary_potential),
                                      nworkers=nworkers,
                                      ntasks=ntasks)
    greedy_objective = greedy_max_matching_solver(array=unary_potential,
                                                  plamb=pairwise_lamb,
                                                  k=k)

    print("unary potential\n: {}".format(unary_potential))
    print("mcf objective\n{}\nvalue : {}".format(
        mcf_objective, get_value(unary_potential, mcf_objective,
                                 pairwise_lamb)))
    print("greedy objective\n{}\nvalue : {}".format(
        greedy_objective,
        get_value(unary_potential, greedy_objective, pairwise_lamb)))
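# For readers of the mcf/greedy comparisons above and below: a plausible
# sketch of get_value. ASSUMPTION: the energy credits the selected unary
# potentials and charges pairwise_lamb once per pair of workers assigned to
# the same task (consistent with the +/- 0.5 * plamb adjustment in test12);
# the authoritative definition lives elsewhere in this repo.
def _get_value_sketch(unary, objective, plamb):
    value = np.sum(unary * objective)
    for j in range(objective.shape[1]):
        n_assigned = np.sum(objective[:, j])
        value -= plamb * n_assigned * (n_assigned - 1) / 2.0
    return value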
def test5():
    '''Performance check'''
    # hyper parameter setting
    plamb = 0.1
    nworkers = 7
    ntasks = 5
    max_label = 2

    unary_potential = np.random.random([nworkers, ntasks])
    plabel = np.random.randint(max_label, size=nworkers)

    mcf = SolveMaxMatching(nworkers=nworkers,
                           ntasks=ntasks,
                           k=1,
                           pairwise_lamb=plamb)
    objective = np.zeros([nworkers, ntasks], dtype=np.float32)
    for p_idx in range(max_label):
        results = mcf.solve_w_label(unary_potential, plabel, p_idx)
        for i, j in results:
            objective[i][j] = 1
    print(objective)
    print(
        solvemaxmatching_label(unary_potential, plabel, max_label,
                               plamb=plamb))
    print(
        results2objective(results=solve_maxmatching_soft_intraclass(
            unary_potential, plabel, plamb1=plamb, plamb2=0.0),
                          nworkers=nworkers,
                          ntasks=ntasks))
def test4():
    '''Speed check comparing
    solvemaxmatching_label and
    solve_maxmatching_soft_intraclass.
    '''
    plamb = 0.0
    nworkers = 64
    ntasks = 8
    max_label = 8
    max_iter = 20

    mcf = SolveMaxMatching(nworkers=nworkers,
                           ntasks=ntasks,
                           k=1,
                           pairwise_lamb=plamb)

    w_label_time_method1 = 0
    w_label_time_method2 = 0
    wo_label_time = 0
    soft_intraclass_time = 0

    for _ in range(max_iter):
        unary_potential = np.random.random([nworkers, ntasks])
        plabel = np.random.randint(max_label, size=nworkers)

        start_time = time.time()
        for p_idx in range(max_label):
            results2objective(results=mcf.solve_w_label(
                unary_potential, plabel, p_idx),
                              nworkers=nworkers,
                              ntasks=ntasks)
        end_time = time.time()
        w_label_time_method1 += end_time - start_time

        start_time = time.time()
        solvemaxmatching_label(unary_potential, plabel, max_label, plamb=plamb)
        end_time = time.time()
        w_label_time_method2 += end_time - start_time

        start_time = time.time()
        results2objective(results=mcf.solve(unary_potential),
                          nworkers=nworkers,
                          ntasks=ntasks)
        end_time = time.time()
        wo_label_time += end_time - start_time

        start_time = time.time()
        results2objective(results=solve_maxmatching_soft_intraclass(
            unary_potential, plabel, plamb1=plamb, plamb2=0.0),
                          nworkers=nworkers,
                          ntasks=ntasks)
        end_time = time.time()
        soft_intraclass_time += end_time - start_time
    print(
        "w label time method1 : {} sec\nw label time method2 : {} sec\nwo label time : {} sec\nsoft intraclass time : {} sec"
        .format(w_label_time_method1, w_label_time_method2, wo_label_time,
                soft_intraclass_time))
def test9():
    '''
    Experiment for rebuttal.
    '''
    pairwise_lamb = 0.3
    nworkers = 20
    ntasks = 4
    k = 1

    mcf = SolveMaxMatching(nworkers=nworkers,
                           ntasks=ntasks,
                           k=k,
                           pairwise_lamb=pairwise_lamb)

    unary_potential = np.random.random([nworkers, ntasks])
    mcf_objective = results2objective(results=mcf.solve(unary_potential),
                                      nworkers=nworkers,
                                      ntasks=ntasks)
    greedy_objective = greedy_max_matching_solver(array=unary_potential,
                                                  plamb=pairwise_lamb,
                                                  k=k)

    max_unary_potential = unary_potential
    max_deviation = get_value(
        unary_potential, greedy_objective, pairwise_lamb) - get_value(
            unary_potential, mcf_objective, pairwise_lamb)

    for _ in range(100000):
        unary_potential = np.random.random([nworkers, ntasks])
        mcf_objective = results2objective(results=mcf.solve(unary_potential),
                                          nworkers=nworkers,
                                          ntasks=ntasks)
        greedy_objective = greedy_max_matching_solver(array=unary_potential,
                                                      plamb=pairwise_lamb,
                                                      k=k)
        deviation = get_value(
            unary_potential, greedy_objective, pairwise_lamb) - get_value(
                unary_potential, mcf_objective, pairwise_lamb)
        if deviation > max_deviation:
            max_deviation = deviation
            max_unary_potential = unary_potential

    unary_potential = max_unary_potential
    mcf_objective = results2objective(results=mcf.solve(unary_potential),
                                      nworkers=nworkers,
                                      ntasks=ntasks)
    greedy_objective = greedy_max_matching_solver(array=unary_potential,
                                                  plamb=pairwise_lamb,
                                                  k=k)

    print("unary potential : {}".format(unary_potential))
    print("mcf objective\n{}\nvalue : {}".format(
        mcf_objective, get_value(unary_potential, mcf_objective,
                                 pairwise_lamb)))
    print("greedy objective\n{}\nvalue : {}".format(
        greedy_objective,
        get_value(unary_potential, greedy_objective, pairwise_lamb)))
def test8():
    '''
    Experiment for rebuttal.
    '''
    pairwise_lamb = 1.0
    nworkers = 64
    ntasks = 64
    k = 1

    mcf = SolveMaxMatching(nworkers=nworkers,
                           ntasks=ntasks,
                           k=k,
                           pairwise_lamb=pairwise_lamb)
    dem = DiscreteEnergyMinimize(ntasks, pairwise_lamb)
    unary_potential = np.random.random([nworkers, ntasks])

    print("unary potential : {}".format(unary_potential))

    ab_start_time = time.time()
    ab_objective = dem.solve(-unary_potential, np.ones([nworkers, nworkers]),
                             k)
    ab_end_time = time.time()

    mcf_start_time = time.time()
    mcf_objective = results2objective(results=mcf.solve(unary_potential),
                                      nworkers=nworkers,
                                      ntasks=ntasks)
    mcf_end_time = time.time()
    greedy_start_time = time.time()
    greedy_objective = greedy_max_matching_solver(array=unary_potential,
                                                  plamb=pairwise_lamb,
                                                  k=k)
    greedy_end_time = time.time()

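    # Unary-only baseline: each worker independently takes its top-k tasks,
    # ignoring the pairwise term entirely.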
    unary_results = list()
    for i in range(nworkers):
        unary_sort = np.argsort(-unary_potential[i])[:k]
        for j in unary_sort:
            unary_results.append([i, j])
    unary_objective = results2objective(results=unary_results,
                                        nworkers=nworkers,
                                        ntasks=ntasks)
    ab_time = ab_end_time - ab_start_time
    mcf_time = mcf_end_time - mcf_start_time
    greedy_time = greedy_end_time - greedy_start_time
    print("mcf({}sec)\nobjective\n{}\nvalue : {}".format(
        mcf_time, mcf_objective,
        get_value(unary_potential, mcf_objective, pairwise_lamb)))
    print("ab({}sec)\nobjective\n{}\nvalue : {}".format(
        ab_time, ab_objective,
        get_value(unary_potential, ab_objective, pairwise_lamb)))
    print("greedy({}sec)\nobjective\n{}\nvalue : {}".format(
        greedy_time, greedy_objective,
        get_value(unary_potential, greedy_objective, pairwise_lamb)))
    print("unary\nobjective\n{}\nvalue : {}".format(
        unary_objective,
        get_value(unary_potential, unary_objective, pairwise_lamb)))
def test19():
    '''Timing comparison:
    solve_maxmatching_soft_intraclass_multiselect
    vs. SolveMaxMatching

    Results -
        mcf : 0.8217835187911987 sec
        smsi : 0.08805792331695557 sec
    '''
    nworkers = 128
    sk = 2
    d = 32
    k = 2
    max_label = 32
    plamb = 0.1

    ntrial = 20
    mcf_time = 0
    smsi_time = 0

    for _ in range(ntrial):
        plabel = np.random.randint(max_label, size=nworkers)

        unary = np.random.random([nworkers, d**k])
        unary1 = np.random.random([nworkers, d])
        unary2 = np.random.random([nworkers, d])
        mcf = SolveMaxMatching(nworkers=nworkers,
                               ntasks=d**k,
                               k=sk,
                               pairwise_lamb=plamb)

        mcf_start_time = time.time()
        results = mcf.solve(unary)
        mcf_end_time = time.time()

        smsi_start_time = time.time()
        solve_maxmatching_soft_intraclass_multiselect(array=unary1,
                                                      k=sk,
                                                      labels=plabel,
                                                      plamb1=plamb,
                                                      plamb2=0.0)
        solve_maxmatching_soft_intraclass_multiselect(array=unary2,
                                                      k=sk,
                                                      labels=plabel,
                                                      plamb1=plamb,
                                                      plamb2=0.0)
        smsi_end_time = time.time()

        mcf_time += mcf_end_time - mcf_start_time
        smsi_time += smsi_end_time - smsi_start_time

    mcf_time /= ntrial
    smsi_time /= ntrial
    print("mcf : {} sec".format(mcf_time))
    print("smsi : {} sec".format(smsi_time))
def test12():
    nclass = 64
    nworkers = 2 * nclass
    ntasks = 64
    k = 1
    plamb = 1.0
    iterations = 100

    mcf = SolveMaxMatching(nworkers=nworkers,
                           ntasks=ntasks,
                           k=k,
                           pairwise_lamb=plamb)
    unary_p = np.random.random([nworkers, ntasks])

    def results2objective(results, nw, nt):
        objective = np.zeros([nw, nt])
        for i, j in results:
            objective[i][j] = 1
        return objective

    unary_p_t = np.zeros([nworkers, ntasks])

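    # Pull each intra-class pair's potentials together by 0.5 * plamb per
    # column (the within-pair gap shrinks by plamb), so the label-blind
    # min-cost-flow solve on unary_p_t implicitly accounts for the
    # intra-class pairwise cost of unary_p.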
    for i in range(nclass):
        for j in range(ntasks):
            if unary_p[2 * i][j] < unary_p[2 * i + 1][j]:
                unary_p_t[2 * i][j] = unary_p[2 * i][j] + 0.5 * plamb
                unary_p_t[2 * i + 1][j] = unary_p[2 * i + 1][j] - 0.5 * plamb
            else:
                unary_p_t[2 * i][j] = unary_p[2 * i][j] - 0.5 * plamb
                unary_p_t[2 * i + 1][j] = unary_p[2 * i + 1][j] + 0.5 * plamb

    objective1 = results2objective(mcf.solve(unary_p_t), nworkers, ntasks)
    objective2 = results2objective(mcf.solve(unary_p), nworkers, ntasks)

    labels = np.reshape(
        np.tile(np.expand_dims(np.arange(nclass), axis=-1), (1, 2)), [-1])

    energy1 = get_value_label(unary_p, objective1, plamb, labels)
    energy2 = get_value_label(unary_p, objective2, plamb, labels)

    print(energy1, energy2)

    for it in range(iterations):
        gr_objective = greedy_max_matching_label_solver_iter(
            unary_p, plamb, labels, it)
        print("{} : {}".format(
            it, get_value_label(unary_p, gr_objective, plamb, labels)))
def test17():
    ''' test for SolveMaxMatchingHungarian
    '''
    nsclass = 4
    ndata_per_class = 3
    ntasks = 6
    ndata = nsclass * ndata_per_class
    plamb = 0.1

    unary = np.random.random([ndata, ntasks])
    unary_class = np.mean(np.reshape(unary, [nsclass, -1, ntasks]), axis=1)

    mcf = SolveMaxMatching(nworkers=nsclass,
                           ntasks=ntasks,
                           k=ndata_per_class,
                           pairwise_lamb=plamb)
    hung = SolveMaxMatchingHungarian(nworkers=ndata_per_class,
                                     ntasks=ndata_per_class,
                                     k=1)

    results_summary = list()
    for i in range(nsclass):
        results_summary.append(list())

    results = mcf.solve(unary_class)
    for i, j in results:
        results_summary[i].append(j)

    objective = np.zeros([ndata, ntasks], dtype=np.float32)  # [nbatch, d]
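    # Stage 2: within each class, Hungarian-match its ndata_per_class samples
    # to the ndata_per_class tasks that stage 1 (mcf) assigned to that class.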
    for i in range(nsclass):
        unary_tmp = np.zeros([ndata_per_class, ndata_per_class])
        for j1 in range(ndata_per_class):
            for j2 in range(ndata_per_class):
                unary_tmp[j1][j2] = unary[ndata_per_class * i +
                                          j1][results_summary[i][j2]]
        results = hung.solve(unary_tmp)
        for a, b in results:
            objective[ndata_per_class * i + a][results_summary[i][b]] = 1

    print("unary : \n{}".format(unary))
    print("unary_class : \n{}".format(unary_class))
    print("results_summary : \n{}".format(results_summary))
    print("objective : \n{}".format(objective))
def test3():
    '''Speed check
    test for SolveMaxMatching, solve_w_label
    '''
    # hyper parameter setting
    plamb = 0.0
    nworkers = 64
    ntasks = 8
    max_label = 8
    max_iter = 20

    mcf = SolveMaxMatching(nworkers=nworkers,
                           ntasks=ntasks,
                           k=1,
                           pairwise_lamb=plamb)

    w_label_time = 0
    wo_label_time = 0
    for _ in range(max_iter):
        unary_potential = np.random.random([nworkers, ntasks])
        plabel = np.random.randint(max_label, size=nworkers)

        start_time = time.time()
        for p_idx in range(max_label):
            mcf.solve_w_label(unary_potential, plabel, p_idx)
        end_time = time.time()
        w_label_time += end_time - start_time

        start_time = time.time()
        mcf.solve(unary_potential)
        end_time = time.time()
        wo_label_time += end_time - start_time
    print("w label time : {} sec\nwo label time : {} sec".format(
        w_label_time, wo_label_time))
def test2():
    '''
    test for SolveMaxMatching, solve_w_label
    '''
    plamb = 0.0
    nworkers = 5
    ntasks = 4
    max_label = 2

    unary_potential = np.random.random([nworkers, ntasks])
    plabel = np.random.randint(max_label, size=nworkers)

    mcf = SolveMaxMatching(nworkers=nworkers,
                           ntasks=ntasks,
                           k=1,
                           pairwise_lamb=plamb)
    print("unary potential :\n{}\nplabel:\n{}".format(unary_potential, plabel))
    for p_idx in range(max_label):
        objective = results2objective(results=mcf.solve_w_label(
            unary_potential, plabel, p_idx),
                                      nworkers=nworkers,
                                      ntasks=ntasks)
        print("p_idx : {}\nobjective\n{}".format(p_idx, objective))
def test1():
    '''
    test for SolveMaxMatching
    '''
    pairwise_lamb = 0.0
    nworkers = 5
    ntasks = 6
    k = 2

    unary_potential = np.random.random([nworkers, ntasks])
    mcf = SolveMaxMatching(nworkers=nworkers,
                           ntasks=ntasks,
                           k=k,
                           pairwise_lamb=pairwise_lamb)
    mcf_results = mcf.solve(unary_potential)
    objective = results2objective(results=mcf_results,
                                  nworkers=nworkers,
                                  ntasks=ntasks)

    print("unary potential:\n{}\nobjective:\n{}".format(
        unary_potential, objective))
class DeepMetric:
    def __init__(self, train_dataset, val_dataset, test_dataset, logfilepath,
                 args):
        self.args = args

        selectGpuById(self.args.gpu)
        self.logfilepath = logfilepath
        self.logger = LoggerManager(self.logfilepath, __name__)

        self.dataset_dict = dict()
        self.set_train_dataset(train_dataset)
        self.set_val_dataset(val_dataset)
        self.set_test_dataset(test_dataset)

    def set_train_dataset(self, train_dataset):
        self.logger.info("Setting train_dataset starts")
        self.train_dataset = train_dataset
        self.dataset_dict['train'] = self.train_dataset
        self.train_image = self.dataset_dict['train'].image
        self.train_label = self.dataset_dict['train'].label
        self.ntrain, self.height, self.width, self.nchannel = self.train_image.shape
        self.ncls_train = self.train_dataset.nclass
        self.nbatch_train = self.ntrain // self.args.nbatch
        self.logger.info("Setting train_dataset ends")

    def set_test_dataset(self, test_dataset):
        self.logger.info("Setting test_dataset starts")
        self.test_dataset = test_dataset
        self.dataset_dict['test'] = self.test_dataset
        self.test_image = self.dataset_dict['test'].image
        self.test_label = self.dataset_dict['test'].label
        self.ncls_test = self.test_dataset.nclass
        self.ntest = self.test_dataset.ndata
        self.nbatch_test = self.ntest // self.args.nbatch
        self.logger.info("Setting test_dataset ends")

    def set_val_dataset(self, val_dataset):
        self.logger.info("Setting val_dataset starts")
        self.val_dataset = val_dataset
        self.dataset_dict['val'] = self.val_dataset
        self.val_image = self.dataset_dict['val'].image
        self.val_label = self.dataset_dict['val'].label
        self.ncls_val = self.val_dataset.nclass
        self.nval = self.val_dataset.ndata
        self.nbatch_val = self.nval // self.args.nbatch
        self.logger.info("Setting val_dataset ends")

    def switch_log_path(self, logfilepath):
        self.logger.remove()
        print("Log file switched from {} to {}".format(self.logfilepath,
                                                       logfilepath))
        self.logfilepath = logfilepath
        self.logger = LoggerManager(self.logfilepath, __name__)

    def build(self, pretrain=False):
        self.logger.info("Model building starts")
        tf.reset_default_graph()
        if self.args.hltype == 'npair':
            self.anc_img = tf.placeholder(tf.float32,
                                          shape=[
                                              self.args.nbatch // 2,
                                              self.height, self.width,
                                              self.nchannel
                                          ])
            self.pos_img = tf.placeholder(tf.float32,
                                          shape=[
                                              self.args.nbatch // 2,
                                              self.height, self.width,
                                              self.nchannel
                                          ])
            self.istrain = tf.placeholder(tf.bool, shape=[])
            self.label = tf.placeholder(tf.int32,
                                        shape=[self.args.nbatch // 2])
        else:  # triplet
            self.img = tf.placeholder(tf.float32,
                                      shape=[
                                          self.args.nbatch, self.height,
                                          self.width, self.nchannel
                                      ])
            self.istrain = tf.placeholder(tf.bool, shape=[])
            self.label = tf.placeholder(tf.int32, shape=[self.args.nbatch])

        self.generate_sess()

        self.conv_net = CONV_DICT[self.args.dataset][self.args.conv]

        if self.args.hltype == 'npair':
            self.anc_last, _ = self.conv_net(self.anc_img,
                                             is_training=self.istrain,
                                             reuse=False)
            self.pos_last, _ = self.conv_net(self.pos_img,
                                             is_training=self.istrain,
                                             reuse=True)
            self.anc_last = tf.nn.relu(self.anc_last)
            self.pos_last = tf.nn.relu(self.pos_last)
        else:
            self.last, _ = self.conv_net(self.img,
                                         is_training=self.istrain,
                                         reuse=False)
            self.last = tf.nn.relu(self.last)

        with slim.arg_scope(
            [slim.fully_connected],
                activation_fn=None,
                weights_regularizer=slim.l2_regularizer(0.0005),
                biases_initializer=tf.zeros_initializer(),
                weights_initializer=tf.truncated_normal_initializer(0.0,
                                                                    0.01)):
            if self.args.hltype == 'npair':
                with tf.variable_scope('Embed', reuse=False):
                    self.anc_embed = slim.fully_connected(self.anc_last,
                                                          self.args.m,
                                                          scope="fc1")
                with tf.variable_scope('Embed', reuse=True):
                    self.pos_embed = slim.fully_connected(self.pos_last,
                                                          self.args.m,
                                                          scope="fc1")
            else:  #triplet
                with tf.variable_scope('Embed', reuse=False):
                    self.embed = slim.fully_connected(self.last,
                                                      self.args.m,
                                                      scope="fc1")

        initialized_variables = get_initialized_vars(self.sess)
        self.logger.info("Variables loaded from pretrained network\n{}".format(
            vars_info_vl(initialized_variables)))
        self.logger.info("Model building ends")

    def build_hash(self):
        self.logger.info("Model building train hash starts")

        self.mcf = SolveMaxMatching(nworkers=self.args.nsclass,
                                    ntasks=self.args.d,
                                    k=self.args.k,
                                    pairwise_lamb=self.args.plamb)

        if self.args.hltype == 'triplet':
            self.objective = tf.placeholder(
                tf.float32, shape=[self.args.nbatch, self.args.d])
        else:
            self.objective = tf.placeholder(
                tf.float32, shape=[self.args.nbatch // 2, self.args.d])

        with slim.arg_scope(
            [slim.fully_connected],
                activation_fn=None,
                weights_regularizer=slim.l2_regularizer(0.0005),
                biases_initializer=tf.zeros_initializer(),
                weights_initializer=tf.truncated_normal_initializer(0.0,
                                                                    0.01)):
            if self.args.hltype == 'triplet':
                self.embed_k_hash = self.last
                with tf.variable_scope('Hash', reuse=False):
                    self.embed_k_hash = slim.fully_connected(self.embed_k_hash,
                                                             self.args.d,
                                                             scope="fc1")
                self.embed_k_hash_l2_norm = tf.nn.l2_normalize(
                    self.embed_k_hash,
                    dim=-1)  # embedding with l2 normalization
                self.pairwise_distance = PAIRWISE_DISTANCE_WITH_OBJECTIVE_DICT[
                    self.args.hdt]
                self.loss_hash = triplet_semihard_loss_hash(labels=self.label, embeddings=self.embed_k_hash_l2_norm, objectives=self.objective,\
                                pairwise_distance=self.pairwise_distance, margin=self.args.hma)
            else:
                self.anc_embed_k_hash = self.anc_last
                self.pos_embed_k_hash = self.pos_last
                with tf.variable_scope('Hash', reuse=False):
                    self.anc_embed_k_hash = slim.fully_connected(
                        self.anc_embed_k_hash, self.args.d, scope="fc1")
                with tf.variable_scope('Hash', reuse=True):
                    self.pos_embed_k_hash = slim.fully_connected(
                        self.pos_embed_k_hash, self.args.d, scope="fc1")
                self.similarity_func = PAIRWISE_SIMILARITY_WITH_OBJECTIVE_DICT[
                    self.args.hdt]
                self.loss_hash = npairs_loss_hash(labels=self.label, embeddings_anchor=self.anc_embed_k_hash, embeddings_positive=self.pos_embed_k_hash,\
                        objective=self.objective, similarity_func=self.similarity_func, reg_lambda=self.args.hlamb)
        self.logger.info("Model building train hash ends")

    def set_up_train_hash(self):
        self.logger.info("Model setting up train hash starts")

        decay_func = DECAY_DICT[self.args.hdtype]
        if hasattr(self, 'start_epoch'):
            self.logger.info("Current start epoch : {}".format(
                self.start_epoch))
            DECAY_PARAMS_DICT[self.args.hdtype][self.args.nbatch][
                self.args.
                hdptype]['initial_step'] = self.nbatch_train * self.start_epoch
        self.lr_hash, update_step_op = decay_func(**DECAY_PARAMS_DICT[
            self.args.hdtype][self.args.nbatch][self.args.hdptype])

        update_ops = tf.get_collection("update_ops")

        var_slow_list, var_fast_list = list(), list()
        for var in tf.trainable_variables():
            if 'Hash' in var.name: var_fast_list.append(var)
            else: var_slow_list.append(var)

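        # New 'Hash' layers train at the full learning rate; the pretrained
        # trunk gets a tenth of it (the [0.1 * lr, lr] pair below).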
        with tf.control_dependencies(update_ops + [update_step_op]):
            self.train_op_hash = get_multi_train_op(
                tf.train.AdamOptimizer, self.loss_hash,
                [0.1 * self.lr_hash, self.lr_hash],
                [var_slow_list, var_fast_list])

        self.EMBED_HASH = self.anc_embed_k_hash if self.args.hltype == 'npair' else self.embed_k_hash
        self.max_k_idx = tf.nn.top_k(self.EMBED_HASH,
                                     k=self.args.k)[1]  # [batch_size, k]

        self.graph_ops_hash_dict = {
            'train': [self.train_op_hash, self.loss_hash],
            'val': self.loss_hash
        }
        self.logger.info("Model setting up train hash ends")

    def generate_sess(self):
        try:
            self.sess
        except AttributeError:
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            self.sess = tf.Session(config=config)

    def initialize(self):
        '''Initialize uninitialized variables'''
        self.logger.info("Model initialization starts")
        self.generate_sess()
        rest_initializer(self.sess)
        self.start_epoch = 0
        self.logger.info("Calculating pairwise distance of validation data")
        val_p_dist = pairwise_distance_euclid_efficient(
            input1=self.val_embed,
            input2=self.val_embed,
            session=self.sess,
            batch_size=self.args.nbatch)
        self.val_arg_sort = np.argsort(val_p_dist, axis=1)
        self.logger.info("Model initialization ends")

    def save(self, global_step, save_dir):
        self.logger.info("Model save starts")
        for f in glob.glob(save_dir + '*'):
            os.remove(f)
        saver = tf.train.Saver(max_to_keep=5)
        saver.save(self.sess,
                   os.path.join(save_dir, 'model'),
                   global_step=global_step)
        self.logger.info("Model save in %s" % save_dir)
        self.logger.info("Model save ends")

    def save_hash(self, global_step, save_dir):
        self.logger.info("Model save starts")
        for f in glob.glob(save_dir + '*'):
            os.remove(f)
        saver = tf.train.Saver(max_to_keep=5)
        saver.save(self.sess,
                   os.path.join(save_dir, 'model'),
                   global_step=global_step)
        self.logger.info("Model save in %s" % save_dir)
        self.logger.info("Model save ends")

    def restore(self, save_dir):
        """Restore all variables in graph with the latest version"""
        self.logger.info("Restoring model starts...")
        saver = tf.train.Saver()
        latest_checkpoint = tf.train.latest_checkpoint(save_dir)
        self.logger.info("Restoring from {}".format(latest_checkpoint))
        self.generate_sess()
        saver.restore(self.sess, latest_checkpoint)
        self.logger.info("Restoring model done.")

    def restore_hash(self, save_dir):
        """Restore all variables in graph with the latest version"""
        self.logger.info("Restoring model starts...")
        saver = tf.train.Saver()
        latest_checkpoint = tf.train.latest_checkpoint(save_dir)
        self.logger.info("Restoring from {}".format(latest_checkpoint))
        self.start_epoch = int(
            os.path.basename(latest_checkpoint)[len('model') + 1:])
        self.generate_sess()
        saver.restore(self.sess, latest_checkpoint)
        self.logger.info("Restoring model done.")

    def run_batch_hash(self, key='train'):
        '''
        Args:
            key - string, one of 'train', 'val', 'test'
        Return:
            outputs of the graph operations registered for `key`
        '''
        assert key in ['train', 'test',
                       'val'], "key should be train or val or test"
        if self.args.hltype == 'npair':
            batch_anc_img, batch_pos_img, batch_anc_label, batch_pos_label = self.dataset_dict[
                key].next_batch(batch_size=self.args.nbatch)
            feed_dict = {
                self.anc_img: batch_anc_img,
                self.pos_img: batch_pos_img,
                self.label: batch_anc_label,
                self.istrain: key == 'train'
            }

            # [self.args.nbatch//2, self.args.d]
            anc_unary, pos_unary = self.sess.run(
                [self.anc_embed_k_hash, self.pos_embed_k_hash],
                feed_dict=feed_dict)

            unary = 0.5 * (anc_unary + pos_unary)  # [batch_size//2, d]
            unary = np.mean(np.reshape(unary,
                                       [self.args.nsclass, -1, self.args.d]),
                            axis=1)  # [nsclass, d]

            results = self.mcf.solve(unary)
            objective = np.zeros([self.args.nsclass, self.args.d],
                                 dtype=np.float32)  # [nsclass, d]
            for i, j in results:
                objective[i][j] = 1
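            # Broadcast the [nsclass, d] objective to per-sample rows: each
            # class's row is repeated nbatch // (2 * nsclass) times so that
            # consecutive samples of a class share one objective row.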
            objective = np.reshape(
                np.transpose(
                    np.tile(np.transpose(objective, [1, 0]),
                            [self.args.nbatch // (2 * self.args.nsclass), 1]),
                    [1, 0]),
                [self.args.nbatch // 2, self.args.d])  # [batch_size//2, d]
            feed_dict[self.objective] = objective
            return self.sess.run(self.graph_ops_hash_dict[key],
                                 feed_dict=feed_dict)
        else:
            batch_img, batch_label = self.dataset_dict[key].next_batch(
                batch_size=self.args.nbatch)
            feed_dict = {
                self.img: batch_img,
                self.label: batch_label,
                self.istrain: key == 'train'
            }

            unary = self.sess.run(self.embed_k_hash_l2_norm,
                                  feed_dict=feed_dict)  # [nbatch, d]
            unary = np.mean(np.reshape(unary,
                                       [self.args.nsclass, -1, self.args.d]),
                            axis=1)  # [nsclass, d]

            results = self.mcf.solve(unary)
            objective = np.zeros([self.args.nsclass, self.args.d],
                                 dtype=np.float32)
            for i, j in results:
                objective[i][j] = 1
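            # Same broadcast as the npair branch, here with
            # nbatch // nsclass copies per class row.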
            objective = np.reshape(
                np.transpose(
                    np.tile(np.transpose(objective, [1, 0]),
                            [self.args.nbatch // self.args.nsclass, 1]),
                    [1, 0]), [self.args.nbatch, -1])  # [batch_size, d]
            feed_dict[self.objective] = objective
            return self.sess.run(self.graph_ops_hash_dict[key],
                                 feed_dict=feed_dict)

    def train_hash(self, epoch, save_dir, board_dir):
        self.logger.info("Model training starts")

        self.train_writer_hash = SummaryWriter(board_dir + 'train')
        self.val_writer_hash = SummaryWriter(board_dir + 'val')

        self.logger.info("Current epoch : {}/{}".format(
            self.start_epoch, epoch))
        self.logger.info("Current lr : {}".format(self.sess.run(self.lr_hash)))

        if self.args.hltype == 'npair':

            def custom_apply_tf_op(inputs, output_gate):
                return apply_tf_op(inputs=inputs,
                                   session=self.sess,
                                   input_gate=self.anc_img,
                                   output_gate=output_gate,
                                   batch_size=self.args.nbatch // 2,
                                   dim=4,
                                   train_gate=self.istrain)
        else:  # triplet

            def custom_apply_tf_op(inputs, output_gate):
                return apply_tf_op(inputs=inputs,
                                   session=self.sess,
                                   input_gate=self.img,
                                   output_gate=output_gate,
                                   batch_size=self.args.nbatch,
                                   dim=4,
                                   train_gate=self.istrain)

        val_max_k_idx = custom_apply_tf_op(inputs=self.val_image,
                                           output_gate=self.max_k_idx)
        val_nmi, val_suf = get_nmi_suf_quick(index_array=val_max_k_idx,
                                             label_array=self.val_label,
                                             ncluster=self.args.d,
                                             nlabel=self.ncls_val)
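        # Precision@1 over hash-filtered candidates: for each validation
        # sample, scan its neighbors in embedding-distance order and count a
        # success if the first neighbor sharing an active bucket also shares
        # its label.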
        nsuccess = 0
        for i in range(self.nval):
            for j in self.val_arg_sort[i]:
                if i == j: continue
                if len(set(val_max_k_idx[j]) & set(val_max_k_idx[i])) > 0:
                    if self.val_label[i] == self.val_label[j]: nsuccess += 1
                    break
        val_p1 = nsuccess / self.nval
        max_val_p1 = val_p1
        self.val_writer_hash.add_summary("suf", val_suf, self.start_epoch)
        self.val_writer_hash.add_summary("nmi", val_nmi, self.start_epoch)
        self.val_writer_hash.add_summary("p1", val_p1, self.start_epoch)

        for epoch_ in range(self.start_epoch, epoch):
            train_epoch_loss = 0
            for _ in tqdm(range(self.nbatch_train), ascii=True, desc="batch"):
                _, batch_loss = self.run_batch_hash(key='train')
                train_epoch_loss += batch_loss

            val_max_k_idx = custom_apply_tf_op(inputs=self.val_image,
                                               output_gate=self.max_k_idx)
            val_nmi, val_suf = get_nmi_suf_quick(index_array=val_max_k_idx,
                                                 label_array=self.val_label,
                                                 ncluster=self.args.d,
                                                 nlabel=self.ncls_val)
            nsuccess = 0
            for i in range(self.nval):
                for j in self.val_arg_sort[i]:
                    if i == j: continue
                    if len(set(val_max_k_idx[j]) & set(val_max_k_idx[i])) > 0:
                        if self.val_label[i] == self.val_label[j]:
                            nsuccess += 1
                        break
            val_p1 = nsuccess / self.nval
            # averaging
            train_epoch_loss /= self.nbatch_train

            self.logger.info("Epoch({}/{}) train loss = {} val suf = {} val nmi = {} val p1 = {}"\
                    .format(epoch_ + 1, epoch, train_epoch_loss, val_suf, val_nmi, val_p1))

            self.train_writer_hash.add_summary("loss", train_epoch_loss,
                                               epoch_ + 1)
            self.train_writer_hash.add_summary("learning rate",
                                               self.sess.run(self.lr_hash),
                                               epoch_ + 1)
            self.val_writer_hash.add_summary("suf", val_suf, epoch_ + 1)
            self.val_writer_hash.add_summary("nmi", val_nmi, epoch_ + 1)
            self.val_writer_hash.add_summary("p1", val_p1, epoch_ + 1)

            if epoch_ == self.start_epoch or max_val_p1 < val_p1:
                max_val_p1 = val_p1
                self.save_hash(epoch_ + 1, save_dir)

        self.logger.info("Model training ends")

    def regen_session(self):
        tf.reset_default_graph()
        self.sess.close()
        self.sess = tf.Session()

    def prepare_test(self):
        self.logger.info("Model preparing test")
        if self.args.hltype == 'npair':

            def custom_apply_tf_op(inputs, output_gate):
                return apply_tf_op(inputs=inputs,
                                   session=self.sess,
                                   input_gate=self.anc_img,
                                   output_gate=output_gate,
                                   batch_size=self.args.nbatch // 2,
                                   dim=4,
                                   train_gate=self.istrain)

            self.test_embed = custom_apply_tf_op(inputs=self.test_image,
                                                 output_gate=self.anc_embed)
            self.val_embed = custom_apply_tf_op(inputs=self.val_image,
                                                output_gate=self.anc_embed)
        else:  # triplet

            def custom_apply_tf_op(inputs, output_gate):
                return apply_tf_op(inputs=inputs,
                                   session=self.sess,
                                   input_gate=self.img,
                                   output_gate=output_gate,
                                   batch_size=self.args.nbatch,
                                   dim=4,
                                   train_gate=self.istrain)

            self.test_embed = custom_apply_tf_op(inputs=self.test_image,
                                                 output_gate=self.embed)
            self.val_embed = custom_apply_tf_op(inputs=self.val_image,
                                                output_gate=self.embed)

    def prepare_test_hash(self):
        self.logger.info("Model preparing test")
        if self.args.hltype == 'npair':

            def custom_apply_tf_op(inputs, output_gate):
                return apply_tf_op(inputs=inputs,
                                   session=self.sess,
                                   input_gate=self.anc_img,
                                   output_gate=output_gate,
                                   batch_size=self.args.nbatch // 2,
                                   dim=4,
                                   train_gate=self.istrain)

            self.test_k_hash = custom_apply_tf_op(
                inputs=self.test_image, output_gate=self.anc_embed_k_hash)
        else:  # triplet

            def custom_apply_tf_op(inputs, output_gate):
                return apply_tf_op(inputs=inputs,
                                   session=self.sess,
                                   input_gate=self.img,
                                   output_gate=output_gate,
                                   batch_size=self.args.nbatch,
                                   dim=4,
                                   train_gate=self.istrain)

            self.test_k_hash = custom_apply_tf_op(
                inputs=self.test_image, output_gate=self.embed_k_hash_l2_norm)

    def test_hash_metric(self, activate_k, k_set):
        self.logger.info("Model testing k hash starts")
        self.logger.info("Activation k(={}) in buckets(={})".format(
            activate_k, self.args.d))

        self.regen_session()
        test_k_activate = activate_k_2D(self.test_k_hash,
                                        k=activate_k,
                                        session=self.sess)  # [ntest, args.d]
        if not hasattr(self, 'te_te_distance'):
            self.regen_session()
            self.logger.info(
                "Calculating pairwise distance from test embeddings")
            self.te_te_distance = pairwise_distance_euclid_efficient(
                input1=self.test_embed,
                input2=self.test_embed,
                session=self.sess,
                batch_size=128)

        performance = evaluate_hash_te(test_hash_key=test_k_activate, te_te_distance=self.te_te_distance,\
                                          te_te_query_key=test_k_activate, te_te_query_value=self.test_k_hash,\
                                          test_label=self.test_label, ncls_test=self.ncls_test,\
                                          activate_k=activate_k, k_set=k_set, logger=self.logger)

        self.logger.info("Model testing k hash ends")
        return performance

    def delete(self):
        tf.reset_default_graph()
        self.logger.remove()
        del self.logger
def test11():
    nclass = 64
    nworkers = 2 * nclass
    ntasks = 5
    k = 1
    plamb = 0.1
    iterations = 100

    mcf = SolveMaxMatching(nworkers=nworkers,
                           ntasks=ntasks,
                           k=k,
                           pairwise_lamb=plamb)
    unary_p = np.random.random([nworkers, ntasks])

    def results2objective(results, nw, nt):
        objective = np.zeros([nw, nt])
        for i, j in results:
            objective[i][j] = 1
        return objective

    unary_p_t = np.zeros([nworkers, ntasks])
    unary_p_t2 = np.zeros([nworkers, ntasks])

    for i in range(nclass):
        for j in range(ntasks):
            if unary_p[2 * i][j] < unary_p[2 * i + 1][j]:
                unary_p_t[2 * i][j] = unary_p[2 * i][j] + 0.5 * plamb
                unary_p_t[2 * i + 1][j] = unary_p[2 * i + 1][j] - 0.5 * plamb
            else:
                unary_p_t[2 * i][j] = unary_p[2 * i][j] - 0.5 * plamb
                unary_p_t[2 * i + 1][j] = unary_p[2 * i + 1][j] + 0.5 * plamb

    # NOTE: unary_p_t2 is never filled in, so objective3 below is solved
    # against an all-zero potential.

    objective1 = results2objective(mcf.solve(unary_p_t), nworkers, ntasks)
    objective2 = results2objective(mcf.solve(unary_p), nworkers, ntasks)
    objective3 = results2objective(mcf.solve(unary_p_t2), nworkers, ntasks)

    DEM = DiscreteEnergyMinimize(ntasks, plamb)
    pairwise_term = np.zeros([nworkers, nworkers])
    for i in range(nclass):
        for j in range(i + 1, nclass):
            pairwise_term[2 * i][2 * j] = 1
            pairwise_term[2 * i + 1][2 * j] = 1
            pairwise_term[2 * i][2 * j + 1] = 1
            pairwise_term[2 * i + 1][2 * j + 1] = 1
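    # pairwise_term is 1 only between samples of *different* classes (upper
    # triangle of class pairs), so the alpha-beta solver penalizes
    # cross-class co-assignments while exempting intra-class pairs.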

    #print("pairwise : {}".format(pairwise_term))
    ab_objective = DEM.solve(-unary_p, pairwise_term, k=k)
    #print("alpha beta objective : \n{}".format(ab_objective))

    results = list()
    for i in range(nworkers):
        topk_idx = np.argsort(-unary_p[i])[:k]
        for j in topk_idx:
            results.append([i, j])
    tr_objective = results2objective(np.array(results), nworkers, ntasks)

    labels = np.reshape(
        np.tile(np.expand_dims(np.arange(nclass), axis=-1), (1, 2)), [-1])
    gr_objective = greedy_max_matching_label_solver_iter(
        unary_p, plamb, labels, 10)

    energy0 = get_value_label(unary_p, objective1, plamb, labels)
    energy1 = get_value_label(unary_p, objective2, plamb, labels)
    energy2 = get_value_label(unary_p, objective3, plamb, labels)
    energy3 = get_value_label(unary_p, ab_objective, plamb, labels)
    energy4 = get_value_label(unary_p, tr_objective, plamb, labels)
    energy5 = get_value_label(unary_p, gr_objective, plamb, labels)

    min_energy = min(energy0, energy1, energy2, energy3, energy4, energy5)
    print(energy0, energy1, energy2, energy3, energy4, energy5)
    print("min energy : {}".format(min_energy))
from csv_op import CsvWriter3

import numpy as np
import time

# ASSUMPTION: the import below is a hypothetical placeholder; adjust it to
# wherever SolveMaxMatching is defined in this repo.
from solve_max_matching import SolveMaxMatching

n_c_set = [32, 64, 128, 256, 512]
d_set = [32, 64, 128, 256, 512]
k = 4
lamb = 1.0
n_iter = 20

cw = CsvWriter3()

for idx1 in range(len(n_c_set)):
    for idx2 in range(len(d_set)):
        print("idx : {}, {}".format(idx1, idx2))
        mcf = SolveMaxMatching(nworkers=n_c_set[idx1],
                               ntasks=d_set[idx2],
                               k=k,
                               pairwise_lamb=lamb)
        time_record = 0
        for _ in range(n_iter):
            unary = np.random.random([n_c_set[idx1], d_set[idx2]])
            start_time = time.time()
            results = mcf.solve(unary)
            end_time = time.time()
            time_record += end_time - start_time
        time_record /= n_iter

        cw.add_content(n_c_set[idx1], d_set[idx2], time_record)

cw.write('./time_record.csv', n_c_set, d_set)