def build(self):
        self.logger.info("Model building starts")
        tf.reset_default_graph()
        tf.set_random_seed(self.args.rseed)

        self.input1 = tf.placeholder(tf.float32, shape=[self.args.nbatch, self.height, self.width, self.nchannel])
        self.epsilon_input = tf.placeholder(tf.float32, shape=[self.args.nbatch, self.args.nconti])
        self.objective = tf.placeholder(tf.float32, shape=[self.args.nbatch, self.args.ncat])
        self.istrain = tf.placeholder(tf.bool, shape=[])

        self.generate_sess()

        self.mcf = SolveMaxMatching(nworkers=self.args.nbatch, ntasks=self.args.ncat, k=1, pairwise_lamb=self.args.plamb)
        # Encoding
        self.encoder_net = lie_encoder1_64
        # self.encoder_net = encoder1_64
        self.decoder_net = lie_decoder1_64

        # Continuous rep
        encode_dict = self.encoder_net(self.input1, output_dim=2*self.args.nconti, scope='encoder', group_feats_size=self.args.group_feats_size, reuse=False)
        self.mean_total, self.stddev_total = tf.split(encode_dict['output'], num_or_size_splits=2, axis=1)
        self.enc_gfeats_mat = encode_dict['gfeats_mat']
        self.stddev_total = tf.nn.softplus(self.stddev_total)
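        # Reparameterization: z = mean + softplus(raw_stddev) * eps, where eps ~ N(0, I)
        # is supplied at run time through the epsilon_input placeholder.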
        self.z_sample = tf.add(self.mean_total, tf.multiply(self.stddev_total, self.epsilon_input))

        decode_dict = self.decoder_net(z=tf.concat([self.z_sample, self.objective], axis=-1), output_channel=self.nchannel, nconti=self.args.nconti, ncat=self.args.ncat, group_feats_size=self.args.group_feats_size, scope="decoder", lie_norm_type=self.args.lie_norm_type, reuse=False)
        self.dec_output = decode_dict['output']
        self.dec_lie_group_mat = decode_dict['lie_group_mat']
        self.dec_lie_alg = decode_dict['lie_alg']
        self.lie_alg_basis = decode_dict['lie_alg_basis'] # [1, lat_dim, mat_dim, mat_dim]

        # Unary vector
        self.rec_cost_vector = sigmoid_cross_entropy_without_mean(labels=self.input1, logits=self.dec_output)

        # Loss
        self.rec_cost = tf.reduce_mean(self.rec_cost_vector)

        self.kl_cost = vae_kl_cost(mean=self.mean_total, stddev=self.stddev_total)
        self.lie_loss = self.calc_lie_loss(self.enc_gfeats_mat, self.dec_lie_group_mat, self.dec_lie_alg, self.lie_alg_basis, self.args.nbatch)
        self.loss = self.rec_cost + self.kl_cost + self.lie_loss

        # Decode
        self.latent_ph = tf.placeholder(tf.float32, shape=[self.args.nbatch, self.args.nconti + self.args.ncat])
        self.dec_output_ph = tf.nn.sigmoid(self.decoder_net(z=self.latent_ph, output_channel=self.nchannel, nconti=self.args.nconti, ncat=self.args.ncat, group_feats_size=self.args.group_feats_size, scope="decoder", lie_norm_type=self.args.lie_norm_type, reuse=True)['output'])

        self.logger.info("Model building ends")
class Model(ModelPlugin):
    def __init__(self, dataset, logfilepath, args):
        super().__init__(dataset, logfilepath, args)
        self.build()

    def build(self):
        self.logger.info("Model building starts")
        tf.reset_default_graph()
        tf.set_random_seed(self.args.rseed)

        self.input1 = tf.placeholder(
            tf.float32,
            shape=[self.args.nbatch, self.height, self.width, self.nchannel])
        self.epsilon_input = tf.placeholder(
            tf.float32, shape=[self.args.nbatch, self.args.nconti])
        self.objective = tf.placeholder(
            tf.float32, shape=[self.args.nbatch, self.args.ncat])
        self.istrain = tf.placeholder(tf.bool, shape=[])

        self.generate_sess()

        self.mcf = SolveMaxMatching(nworkers=self.args.nbatch,
                                    ntasks=self.args.ncat,
                                    k=1,
                                    pairwise_lamb=self.args.plamb)
        # Encoding
        self.encoder_net = encoder1_64
        self.decoder_net = group_spl_decoder1_64

        # Continuous rep
        self.mean_total, self.stddev_total = tf.split(self.encoder_net(
            self.input1,
            output_dim=2 * self.args.nconti,
            scope='encoder',
            reuse=False)['output'],
                                                      num_or_size_splits=2,
                                                      axis=1)
        # encode_dict = self.encoder_net(self.input1, output_dim=2*self.args.nconti, scope='encoder', group_feats_size=self.args.group_feats_size, reuse=False)
        # self.mean_total, self.stddev_total = tf.split(encode_dict['output'], num_or_size_splits=2, axis=1)
        # self.enc_gfeats_mat = encode_dict['gfeats_mat']
        self.stddev_total = tf.nn.softplus(self.stddev_total)
        self.z_sample = tf.add(
            self.mean_total, tf.multiply(self.stddev_total,
                                         self.epsilon_input))

        self.z_sample_sum = (self.z_sample[:self.args.nbatch // 2] +
                             self.z_sample[self.args.nbatch // 2:])
        # z_sampled_split_ls = split_latents(self.z_sample, self.args.nbatch, hy_ncut=self.args.ncut)
        # self.z_sampled_split = tf.concat(z_sampled_split_ls, axis=0)
        # self.objective_split = tf.tile(self.objective, [len(z_sampled_split_ls), 1])

        # self.z_sample_all = tf.concat([self.z_sample, self.z_sample_sum, self.z_sampled_split], axis=0)
        # self.objective_all = tf.concat([self.objective, self.objective[:self.args.nbatch // 2], self.objective_split], axis=0)

        self.z_sample_all = tf.concat([self.z_sample, self.z_sample_sum],
                                      axis=0)
        self.objective_all = tf.concat(
            [self.objective, self.objective[:self.args.nbatch // 2]], axis=0)
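        # The decoder sees the original nbatch codes plus nbatch // 2 extra codes formed by
        # summing the two halves of z_sample; the categorical targets for the extra codes are
        # copied from the first half. These extra entries only feed the group-multiplication
        # consistency term (gmat_loss) in calc_lie_loss, while the reconstruction loss below
        # uses just the first nbatch decoder outputs.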

        decode_dict = self.decoder_net(
            z=tf.concat([self.z_sample_all, self.objective_all], axis=-1),
            output_channel=self.nchannel,
            n_act_points=self.args.n_act_points,
            lie_alg_init_type=self.args.lie_alg_init_type,
            nconti=self.args.nconti,
            ncat=self.args.ncat,
            group_feats_size=self.args.group_feats_size,
            ncut=self.args.ncut,
            scope="decoder",
            reuse=False,
            is_train=self.istrain)
        self.dec_output = decode_dict['output']
        self.dec_lie_group_mat = decode_dict['lie_group_mat']
        self.dec_lie_alg = decode_dict['lie_alg']
        self.lie_alg_basis = decode_dict[
            'lie_alg_basis']  # [1, lat_dim, mat_dim, mat_dim]
        self.act_points = decode_dict[
            'act_points']  # [b, mat_dim, n_act_points]

        # Unary vector
        self.rec_cost_vector = sigmoid_cross_entropy_without_mean(
            labels=self.input1, logits=self.dec_output[:self.args.nbatch])

        # Loss
        self.rec_cost = tf.reduce_mean(self.rec_cost_vector)

        self.kl_cost = vae_kl_cost(mean=self.mean_total,
                                   stddev=self.stddev_total)
        self.lie_loss = self.calc_lie_loss(self.dec_lie_group_mat,
                                           self.dec_lie_alg,
                                           self.lie_alg_basis, self.act_points,
                                           self.args.hessian_type,
                                           self.args.nbatch)
        self.loss = self.rec_cost + self.args.beta * self.kl_cost + self.lie_loss

        # Decode
        self.latent_ph = tf.placeholder(
            tf.float32,
            shape=[self.args.nbatch, self.args.nconti + self.args.ncat])
        self.dec_output_ph = tf.nn.sigmoid(
            self.decoder_net(z=self.latent_ph,
                             output_channel=self.nchannel,
                             n_act_points=self.args.n_act_points,
                             lie_alg_init_type=self.args.lie_alg_init_type,
                             nconti=self.args.nconti,
                             ncat=self.args.ncat,
                             group_feats_size=self.args.group_feats_size,
                             ncut=self.args.ncut,
                             scope="decoder",
                             reuse=True,
                             is_train=self.istrain)['output'])

        self.logger.info("Model building ends")

    def calc_lie_loss(self, group_feats_G, dec_lie_alg, lie_alg_basis,
                      act_points, hessian_type, nbatch):
        mat_dim = group_feats_G.get_shape().as_list()[1]

        group_feats_G_ori = group_feats_G[:nbatch]
        group_feats_G_sum = group_feats_G[nbatch:nbatch + nbatch // 2]
        # gfeats_G_split_ls = [
        # group_feats_G[(i + 1) * nbatch + nbatch // 2:
        # (i + 2) * nbatch + nbatch // 2]
        # for i in range(self.args.ncut + 1)
        # ]

        group_feats_G_mul = tf.matmul(group_feats_G[:nbatch // 2],
                                      group_feats_G[nbatch // 2:nbatch])

        # gfeats_G_split_mul = gfeats_G_split_ls[0]
        # for i in range(1, self.args.ncut + 1):
        # gfeats_G_split_mul = tf.matmul(gfeats_G_split_mul,
        # gfeats_G_split_ls[i])

        lie_alg_basis_square = lie_alg_basis * lie_alg_basis
        # [1, lat_dim, mat_dim, mat_dim]
        _, lat_dim, mat_dim, _ = lie_alg_basis.get_shape().as_list()
        lie_alg_basis_col = tf.reshape(lie_alg_basis,
                                       [lat_dim, 1, mat_dim, mat_dim])
        lie_alg_basis_mul = tf.matmul(lie_alg_basis, lie_alg_basis_col)
        lie_alg_basis_mask = 1. - tf.eye(
            lat_dim, dtype=lie_alg_basis_mul.dtype)[:, :, tf.newaxis,
                                                    tf.newaxis]
        lie_alg_basis_mul = lie_alg_basis_mul * lie_alg_basis_mask
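        # Shape sketch for the terms above (the concrete lat_dim/mat_dim values below are
        # illustrative assumptions, not taken from the config):
        #   lie_alg_basis      [1, lat_dim, mat_dim, mat_dim]
        #   lie_alg_basis_col  [lat_dim, 1, mat_dim, mat_dim]
        #   lie_alg_basis_mul  [lat_dim, lat_dim, mat_dim, mat_dim]  (matmul broadcasts the
        #                      first two axes); the mask zeroes the i == j products so only
        #                      cross-terms feed the Hessian penalty below.
        # NumPy analogue with lat_dim=3, mat_dim=4:
        #   basis = np.random.randn(1, 3, 4, 4)
        #   mul   = np.matmul(basis, basis.reshape(3, 1, 4, 4))          # (3, 3, 4, 4)
        #   mask  = 1.0 - np.eye(3)[:, :, None, None]
        #   loss  = np.mean(np.sum((mul * mask) ** 2, axis=(2, 3)))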

        gmat_loss = tf.reduce_mean(
            tf.reduce_sum(tf.square(group_feats_G_mul - group_feats_G_sum),
                          axis=[1, 2]))
        # spl_loss = tf.reduce_mean(
        # tf.reduce_sum(tf.square(gfeats_G_split_mul - group_feats_G_ori),
        # axis=[1, 2]))
        lin_loss = tf.reduce_mean(
            tf.reduce_sum(lie_alg_basis_square, axis=[2, 3]))
        if hessian_type == 'no_act_points':
            hessian_loss = tf.reduce_mean(
                tf.reduce_sum(tf.square(lie_alg_basis_mul), axis=[2, 3]))
        elif hessian_type == 'with_act_points':
            # act_points: [b, mat_dim, n_act_points]
            # lie_alg_basis_mul: [lat_dim, lat_dim, mat_dim, mat_dim]
            # For a more efficient implementation, only act_points[:1] is used here.
            lie_act_mul = tf.matmul(lie_alg_basis_mul, act_points[:1])
            # [lat_dim, lat_dim, mat_dim, n_act_points]
            # print('lie_act_mul.shape:', lie_act_mul.get_shape().as_list())
            hessian_loss = tf.reduce_mean(
                tf.reduce_sum(tf.square(lie_act_mul), axis=[2, 3]))
        else:
            raise ValueError('Unrecognized hessian_type: {}'.format(hessian_type))

        # loss = self.args.gmat * gmat_loss * self.args.spl * spl_loss + self.args.hes * hessian_loss + self.args.lin * lin_loss
        loss = self.args.gmat * gmat_loss + self.args.hes * hessian_loss + self.args.lin * lin_loss
        return loss

    def decode(self, latent_input):
        return apply_tf_op(inputs=latent_input,
                           session=self.sess,
                           input_gate=self.latent_ph,
                           output_gate=self.dec_output_ph,
                           batch_size=self.args.nbatch,
                           train_gate=self.istrain)

    def set_up_train(self):
        self.logger.info("Model setting up train starts")

        if not hasattr(self, 'start_iter'): self.start_iter = 0
        self.logger.info("Start iter: {}".format(self.start_iter))

        decay_func = DECAY_DICT[self.args.dtype]
        decay_params = DECAY_PARAMS_DICT[self.args.dtype][self.args.nbatch][
            self.args.dptype].copy()
        decay_params['initial_step'] = self.start_iter

        self.lr, update_step_op = decay_func(**decay_params)
        self.update_step_op = [update_step_op]

        var_list = [
            v for v in tf.trainable_variables() if 'encoder' in v.name
        ] + [v for v in tf.trainable_variables() if 'decoder' in v.name]

        # self.train_op_dict = dict()
        # with tf.control_dependencies(tf.get_collection("update_ops")):
        # for idx in range(self.args.nconti+1):
        # self.train_op_dict[idx] = get_train_op_v2(tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.9, beta2=0.999), loss=self.loss_dict[idx], var_list=var_list)
        self.train_op = get_train_op_v2(tf.train.AdamOptimizer(
            learning_rate=self.lr, beta1=0.9, beta2=0.999),
                                        loss=self.loss,
                                        var_list=var_list)

        self.logger.info("Model setting up train ends")

    def run_batch(self, train_idx):
        feed_dict = dict()
        feed_dict[self.input1] = self.dataset.next_batch(
            batch_size=self.args.nbatch)[0]
        feed_dict[self.istrain] = True
        feed_dict[self.epsilon_input] = np.random.normal(
            size=[self.args.nbatch, self.args.nconti])

        if train_idx < self.args.ntime:
            feed_dict[self.objective] = np.zeros(
                [self.args.nbatch, self.args.ncat])
        else:
            unary = np.zeros([self.args.nbatch, self.args.ncat])
            for idx in range(self.args.ncat):
                feed_dict[self.objective] = np.tile(
                    np.reshape(np.eye(self.args.ncat)[idx], [1, -1]),
                    [self.args.nbatch, 1])
                unary[:, idx] = self.sess.run(self.rec_cost_vector,
                                              feed_dict=feed_dict)
            feed_dict[self.objective] = self.mcf.solve(-unary)[1]
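            # unary[i, idx] now holds the reconstruction cost of sample i under the idx-th
            # one-hot categorical code; the matching solver turns the negated costs into a
            # one-hot assignment of shape [nbatch, ncat] that is fed back as the objective.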

        # if train_idx>=self.args.ntime:
        # idx = min(train_idx, self.args.nconti)
        # else:
        # idx = min(train_idx+1, self.args.nconti)
        self.sess.run(self.train_op, feed_dict=feed_dict)

    def train(self, niter, piter, siter, save_dir=None, asset_dir=None):
        self.logger.info("Model training starts")

        final_iter = self.start_iter + niter
        max_accuracy = -1
        max_acc_iter = -1

        for iter_ in tqdm_range(self.start_iter, final_iter):
            # train_idx = (iter_ - self.start_iter)//piter
            train_idx = iter_ // piter
            self.run_batch(train_idx)

            if (iter_ + 1) % siter == 0 or iter_ + 1 == final_iter:
                include_discrete = train_idx >= self.args.ntime
                accuracy = self.evaluate(include_discrete=include_discrete)

                self.latent_traversal_gif(path=asset_dir +
                                          '{}.gif'.format(iter_ + 1),
                                          include_discrete=include_discrete)
                if max_accuracy == -1 or max_accuracy < accuracy:
                    self.save(iter_, save_dir)
                    self.logger.info("Save process")
                    max_accuracy = accuracy
                    max_acc_iter = iter_
                print('max_accuracy:', max_accuracy)
                self.logger.info('max_accuracy: ' + str(max_accuracy))
                self.logger.info('max_acc_iter: ' + str(max_acc_iter))
        self.logger.info("Model training ends")

    def evaluate(self,
                 print_option=False,
                 include_discrete=False,
                 eps=1e-8,
                 nsample=1024):
        if include_discrete:
            total_mean, total_std, latent_cat = self.get_latent_total()
            return DisentanglemetricFactorJointMask(
                mean=total_mean,
                std=total_std,
                latent_cat=latent_cat,
                nclasses=self.dataset.latents_sizes,
                sampler=self.dataset.next_batch_latent_fix_idx,
                print_option=print_option,
                ignore_discrete=False)
        else:
            total_mean, total_std = self.get_mean_std()
            return DisentanglemetricFactorMask(
                mean=total_mean,
                std=total_std,
                nclasses=self.dataset.latents_sizes,
                sampler=self.dataset.next_batch_latent_fix_idx,
                print_option=print_option)

    def get_mean_std(self):
        total_mean, total_std = apply_tf_op_multi_output(
            inputs=self.image,
            session=self.sess,
            input_gate=self.input1,
            output_gate_list=[self.mean_total, self.stddev_total],
            batch_size=self.args.nbatch,
            train_gate=self.istrain)
        return total_mean, total_std

    def get_latent_total(self):
        total_mean, total_std = self.get_mean_std()
        unary = np.zeros([self.ndata, self.args.ncat])
        for idx in range(self.args.ncat):
            unary[:, idx] = apply_tf_op_multi_input(
                inputs_list=[
                    self.image,
                    np.zeros([self.ndata, self.args.nconti]),
                    np.tile(np.reshape(np.eye(self.args.ncat)[idx], [1, -1]),
                            [self.ndata, 1])
                ],
                session=self.sess,
                input_gate_list=[
                    self.input1, self.epsilon_input, self.objective
                ],
                output_gate=self.rec_cost_vector,
                batch_size=self.args.nbatch,
                train_gate=self.istrain)
        latent_cat = np_softmax(-unary)
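        # Softmax over the negated reconstruction costs turns unary into a per-sample
        # categorical distribution: codes that reconstruct a sample better get more mass.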
        return total_mean, total_std, latent_cat

    def latent_traversal_gif(self,
                             path,
                             include_discrete=False,
                             nimage=50,
                             nmin=-1.0,
                             nmax=1.0):
        gif = list()
        for i in range(nimage):
            value = nmin + (nmax - nmin) * i / nimage
            latent_conti = value * np.eye(self.args.nconti)
            if include_discrete:
                latent_cat = np.eye(self.args.ncat)
                gif.append(
                    matrix_image2big_image(
                        np.concatenate([
                            np.expand_dims(
                                self.decode(latent_input=np.concatenate([
                                    latent_conti,
                                    np.tile(
                                        np.expand_dims(latent_cat[j], axis=0),
                                        [self.args.nconti, 1])
                                ],
                                                                        axis=1)
                                            ),
                                axis=0) for j in range(self.args.ncat)
                        ],
                                       axis=0)))
            else:
                latent_cat = np.zeros([self.args.ncat])
                gif.append(
                    matrix_image2big_image(
                        np.expand_dims(
                            self.decode(latent_input=np.concatenate([
                                latent_conti,
                                np.tile(np.expand_dims(latent_cat, axis=0),
                                        [self.args.nconti, 1])
                            ],
                                                                    axis=1)),
                            axis=0)))
        write_gif(content=gif, path=path)
    def build(self):
        self.logger.info("Model building starts")
        tf.reset_default_graph()
        tf.set_random_seed(self.args.rseed)

        self.input1 = tf.placeholder(
            tf.float32,
            shape=[self.args.nbatch, self.height, self.width, self.nchannel])
        self.epsilon_input = tf.placeholder(
            tf.float32, shape=[self.args.nbatch, self.args.nconti])
        self.objective = tf.placeholder(
            tf.float32, shape=[self.args.nbatch, self.args.ncat])
        self.istrain = tf.placeholder(tf.bool, shape=[])
        self.I_weight = tf.placeholder(tf.float32, shape=[])
        self.F_weight = tf.placeholder(tf.float32, shape=[])

        # For VC-Loss
        self.delta_dim = tf.placeholder(tf.int32, shape=[self.args.nbatch])
        if self.args.use_discrete:
            self.objective_2_idx = tf.placeholder(tf.int32,
                                                  shape=[self.args.nbatch])
        else:
            self.objective_2 = tf.placeholder(
                tf.float32, shape=[self.args.nbatch, self.args.ncat])

        self.generate_sess()

        self.mcf = SolveMaxMatching(nworkers=self.args.nbatch,
                                    ntasks=self.args.ncat,
                                    k=1,
                                    pairwise_lamb=self.args.plamb)
        # Encoding
        self.encoder_net = encoder1_64
        self.decoder_net = decoder1_64
        self.disc_net = disc_net_64

        # Continuous rep
        self.mean_total, self.stddev_total = tf.split(self.encoder_net(
            self.input1,
            output_dim=2 * self.args.nconti,
            scope='encoder',
            reuse=False)['output'],
                                                      num_or_size_splits=2,
                                                      axis=1)
        self.stddev_total = tf.nn.softplus(self.stddev_total)
        self.z_sample = tf.add(
            self.mean_total, tf.multiply(self.stddev_total,
                                         self.epsilon_input))

        # For VC-Loss
        if self.args.delta_type == 'onedim':
            # C_delta_latents = tf.random.uniform([minibatch_size], minval=0, maxval=C_global_size, dtype=tf.int32)
            # C_delta_latents = tf.cast(tf.one_hot(C_delta_latents, C_global_size), latents.dtype)
            self.z_delta = tf.cast(
                tf.one_hot(self.delta_dim, self.args.nconti),
                self.z_sample.dtype)
            rand_eps = tf.random.normal([self.args.nbatch, 1],
                                        mean=0.0,
                                        stddev=2.0)
            self.delta_target = self.z_delta * rand_eps
            self.z_added = self.delta_target
            self.z_added = self.z_added + self.z_sample
        elif self.args.delta_type == 'fulldim':
            # C_delta_latents = tf.random.uniform([minibatch_size, C_global_size], minval=0, maxval=1.0, dtype=latents.dtype)
            self.delta_target = tf.random.uniform(
                [self.args.nbatch, self.args.nconti],
                minval=0,
                maxval=1.0,
                dtype=self.z_sample.dtype)
            self.z_added = (self.delta_target - 0.5) * self.args.vc_epsilon
            self.z_added = self.z_added + self.z_sample
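        # 'onedim' perturbs one randomly chosen continuous dimension by a Gaussian step
        # (the one-hot z_delta times rand_eps); 'fulldim' shifts every continuous dimension
        # by a uniform offset scaled by vc_epsilon. In both cases delta_target is the
        # quantity the discriminator is later trained to recover.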

        self.dec_output_dict = self.decoder_net(z=tf.concat(
            [self.z_sample, self.objective], axis=-1),
                                                output_channel=self.nchannel,
                                                scope="decoder",
                                                reuse=False)
        self.dec_output = self.dec_output_dict['output']
        self.feat_output = self.dec_output_dict['deconv2d2']
        self.F_loss = tf.reduce_mean(self.feat_output * self.feat_output)
        self.F_loss = self.args.F_beta * self.F_loss

        if self.args.use_discrete:
            self.objective_2 = tf.cast(
                tf.one_hot(self.objective_2_idx, self.args.ncat),
                self.z_added.dtype)
        self.dec_output_2 = self.decoder_net(z=tf.concat(
            [self.z_added, self.objective_2], axis=-1),
                                             output_channel=self.nchannel,
                                             scope="decoder",
                                             reuse=True)['output']
        self.disc_output = self.disc_net(img1=self.dec_output,
                                         img2=self.dec_output_2,
                                         target_dim=self.args.nconti,
                                         scope='discriminator',
                                         reuse=False)['output']

        if self.args.delta_type == 'onedim':
            # Loss VC CEloss
            self.disc_prob = tf.nn.softmax(self.disc_output, axis=1)
            self.I_loss = tf.reduce_mean(
                tf.reduce_sum(self.z_delta * tf.log(self.disc_prob + 1e-12),
                              axis=1))
            self.I_loss = -self.args.C_lambda * self.I_loss
        elif self.args.delta_type == 'fulldim':
            # Loss VC MSEloss
            self.I_loss = tf.reduce_mean(
                tf.reduce_sum(
                    (tf.nn.sigmoid(self.disc_output) - self.delta_target)**2,
                    axis=1))
            self.I_loss = self.args.C_lambda * self.I_loss
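        # The VC term: a cross-entropy against the one-hot perturbed dimension in the
        # 'onedim' case, and an MSE regression of delta_target in the 'fulldim' case,
        # both scaled by C_lambda and gated later by the I_weight placeholder.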

        # Unary vector
        self.rec_cost_vector = sigmoid_cross_entropy_without_mean(
            labels=self.input1, logits=self.dec_output)

        # Loss
        self.rec_cost = tf.reduce_mean(self.rec_cost_vector)

        weight = tf.constant(np.array(self.args.nconti * [self.args.beta_max]),
                             dtype=tf.float32)
        kl_cost = vae_kl_cost_weight(mean=self.mean_total,
                                     stddev=self.stddev_total,
                                     weight=weight)
        self.loss = (self.rec_cost + kl_cost +
                     tf.losses.get_regularization_loss() +
                     self.I_loss * self.I_weight + self.F_loss * self.F_weight)

        tf.summary.scalar('rec_loss', self.rec_cost)
        tf.summary.scalar('I_loss', self.I_loss)
        tf.summary.scalar('F_loss', self.F_loss)
        self.merged = tf.summary.merge_all()

        # Decode
        self.latent_ph = tf.placeholder(
            tf.float32,
            shape=[self.args.nbatch, self.args.nconti + self.args.ncat])
        self.dec_output_ph = tf.nn.sigmoid(
            self.decoder_net(z=self.latent_ph,
                             output_channel=self.nchannel,
                             scope="decoder",
                             reuse=True)['output'])

        # Free Batch Decode
        self.free_latent_ph = tf.placeholder(
            tf.float32, shape=[None, self.args.nconti + self.args.ncat])
        self.free_dec_output_ph = tf.nn.sigmoid(
            self.decoder_net(z=self.free_latent_ph,
                             output_channel=self.nchannel,
                             scope="decoder",
                             reuse=True)['output'])

        self.logger.info("Model building ends")
class Model(ModelPlugin):
    def __init__(self, dataset, logfilepath, args):
        super().__init__(dataset, logfilepath, args)
        self.build()

    def build(self):
        self.logger.info("Model building starts")
        tf.reset_default_graph()
        tf.set_random_seed(self.args.rseed)

        self.input1 = tf.placeholder(
            tf.float32,
            shape=[self.args.nbatch, self.height, self.width, self.nchannel])
        self.epsilon_input = tf.placeholder(
            tf.float32, shape=[self.args.nbatch, self.args.nconti])
        self.objective = tf.placeholder(
            tf.float32, shape=[self.args.nbatch, self.args.ncat])
        self.istrain = tf.placeholder(tf.bool, shape=[])
        self.I_weight = tf.placeholder(tf.float32, shape=[])
        self.F_weight = tf.placeholder(tf.float32, shape=[])

        # For VC-Loss
        self.delta_dim = tf.placeholder(tf.int32, shape=[self.args.nbatch])
        if self.args.use_discrete:
            self.objective_2_idx = tf.placeholder(tf.int32,
                                                  shape=[self.args.nbatch])
        else:
            self.objective_2 = tf.placeholder(
                tf.float32, shape=[self.args.nbatch, self.args.ncat])

        self.generate_sess()

        self.mcf = SolveMaxMatching(nworkers=self.args.nbatch,
                                    ntasks=self.args.ncat,
                                    k=1,
                                    pairwise_lamb=self.args.plamb)
        # Encoding
        self.encoder_net = encoder1_64
        self.decoder_net = decoder1_64
        self.disc_net = disc_net_64

        # Continuous rep
        self.mean_total, self.stddev_total = tf.split(self.encoder_net(
            self.input1,
            output_dim=2 * self.args.nconti,
            scope='encoder',
            reuse=False)['output'],
                                                      num_or_size_splits=2,
                                                      axis=1)
        self.stddev_total = tf.nn.softplus(self.stddev_total)
        self.z_sample = tf.add(
            self.mean_total, tf.multiply(self.stddev_total,
                                         self.epsilon_input))

        # For VC-Loss
        if self.args.delta_type == 'onedim':
            # C_delta_latents = tf.random.uniform([minibatch_size], minval=0, maxval=C_global_size, dtype=tf.int32)
            # C_delta_latents = tf.cast(tf.one_hot(C_delta_latents, C_global_size), latents.dtype)
            self.z_delta = tf.cast(
                tf.one_hot(self.delta_dim, self.args.nconti),
                self.z_sample.dtype)
            rand_eps = tf.random.normal([self.args.nbatch, 1],
                                        mean=0.0,
                                        stddev=2.0)
            self.delta_target = self.z_delta * rand_eps
            self.z_added = self.delta_target
            self.z_added = self.z_added + self.z_sample
        elif self.args.delta_type == 'fulldim':
            # C_delta_latents = tf.random.uniform([minibatch_size, C_global_size], minval=0, maxval=1.0, dtype=latents.dtype)
            self.delta_target = tf.random.uniform(
                [self.args.nbatch, self.args.nconti],
                minval=0,
                maxval=1.0,
                dtype=self.z_sample.dtype)
            self.z_added = (self.delta_target - 0.5) * self.args.vc_epsilon
            self.z_added = self.z_added + self.z_sample

        self.dec_output_dict = self.decoder_net(z=tf.concat(
            [self.z_sample, self.objective], axis=-1),
                                                output_channel=self.nchannel,
                                                scope="decoder",
                                                reuse=False)
        self.dec_output = self.dec_output_dict['output']
        self.feat_output = self.dec_output_dict['deconv2d2']
        self.F_loss = tf.reduce_mean(self.feat_output * self.feat_output)
        self.F_loss = self.args.F_beta * self.F_loss

        if self.args.use_discrete:
            self.objective_2 = tf.cast(
                tf.one_hot(self.objective_2_idx, self.args.ncat),
                self.z_added.dtype)
        self.dec_output_2 = self.decoder_net(z=tf.concat(
            [self.z_added, self.objective_2], axis=-1),
                                             output_channel=self.nchannel,
                                             scope="decoder",
                                             reuse=True)['output']
        self.disc_output = self.disc_net(img1=self.dec_output,
                                         img2=self.dec_output_2,
                                         target_dim=self.args.nconti,
                                         scope='discriminator',
                                         reuse=False)['output']

        if self.args.delta_type == 'onedim':
            # Loss VC CEloss
            self.disc_prob = tf.nn.softmax(self.disc_output, axis=1)
            self.I_loss = tf.reduce_mean(
                tf.reduce_sum(self.z_delta * tf.log(self.disc_prob + 1e-12),
                              axis=1))
            self.I_loss = -self.args.C_lambda * self.I_loss
        elif self.args.delta_type == 'fulldim':
            # Loss VC MSEloss
            self.I_loss = tf.reduce_mean(
                tf.reduce_sum(
                    (tf.nn.sigmoid(self.disc_output) - self.delta_target)**2,
                    axis=1))
            self.I_loss = self.args.C_lambda * self.I_loss

        # Unary vector
        self.rec_cost_vector = sigmoid_cross_entropy_without_mean(
            labels=self.input1, logits=self.dec_output)

        # Loss
        self.rec_cost = tf.reduce_mean(self.rec_cost_vector)

        weight = tf.constant(np.array(self.args.nconti * [self.args.beta_max]),
                             dtype=tf.float32)
        kl_cost = vae_kl_cost_weight(mean=self.mean_total,
                                     stddev=self.stddev_total,
                                     weight=weight)
        self.loss = (self.rec_cost + kl_cost +
                     tf.losses.get_regularization_loss() +
                     self.I_loss * self.I_weight + self.F_loss * self.F_weight)

        tf.summary.scalar('rec_loss', self.rec_cost)
        tf.summary.scalar('I_loss', self.I_loss)
        tf.summary.scalar('F_loss', self.F_loss)
        self.merged = tf.summary.merge_all()

        # Decode
        self.latent_ph = tf.placeholder(
            tf.float32,
            shape=[self.args.nbatch, self.args.nconti + self.args.ncat])
        self.dec_output_ph = tf.nn.sigmoid(
            self.decoder_net(z=self.latent_ph,
                             output_channel=self.nchannel,
                             scope="decoder",
                             reuse=True)['output'])

        # Free Batch Decode
        self.free_latent_ph = tf.placeholder(
            tf.float32, shape=[None, self.args.nconti + self.args.ncat])
        self.free_dec_output_ph = tf.nn.sigmoid(
            self.decoder_net(z=self.free_latent_ph,
                             output_channel=self.nchannel,
                             scope="decoder",
                             reuse=True)['output'])

        self.logger.info("Model building ends")

    def decode(self, latent_input):
        return apply_tf_op(inputs=latent_input,
                           session=self.sess,
                           input_gate=self.latent_ph,
                           output_gate=self.dec_output_ph,
                           batch_size=self.args.nbatch)

    def set_up_train(self):
        self.logger.info("Model setting up train starts")

        if not hasattr(self, 'start_iter'): self.start_iter = 0
        self.logger.info("Start iter: {}".format(self.start_iter))

        decay_func = DECAY_DICT[self.args.dtype]
        decay_params = DECAY_PARAMS_DICT[self.args.dtype][self.args.nbatch][
            self.args.dptype].copy()
        decay_params['initial_step'] = self.start_iter

        self.lr, update_step_op = decay_func(**decay_params)
        self.update_step_op = [update_step_op]

        var_list = [v for v in tf.trainable_variables() if 'encoder' in v.name] + \
                [v for v in tf.trainable_variables() if 'decoder' in v.name] + \
                [v for v in tf.trainable_variables() if 'discriminator' in v.name]

        with tf.control_dependencies(tf.get_collection("update_ops")):
            self.train_op = get_train_op_v2(tf.train.AdamOptimizer(
                learning_rate=self.lr, beta1=0.9, beta2=0.999),
                                            loss=self.loss,
                                            var_list=var_list)

        self.logger.info("Model setting up train ends")

    def run_batch(self, train_idx):
        feed_dict = dict()
        feed_dict[self.input1] = self.dataset.next_batch(
            batch_size=self.args.nbatch)[0]
        feed_dict[self.istrain] = True
        feed_dict[self.epsilon_input] = np.random.normal(
            size=[self.args.nbatch, self.args.nconti])

        # For VC-Loss
        feed_dict[self.delta_dim] = np.random.randint(0,
                                                      self.args.nconti,
                                                      size=[self.args.nbatch])
        # feed_dict[self.objective_2_idx] = np.random.randint(0, self.args.ncat, size=[self.args.nbatch])
        feed_dict[self.objective_2] = np.zeros(
            [self.args.nbatch, self.args.ncat])

        if self.args.use_discrete:
            # with discrete
            if train_idx < self.args.ntime:
                feed_dict[self.objective] = np.zeros(
                    [self.args.nbatch, self.args.ncat])
                feed_dict[self.I_weight] = 1.
                feed_dict[self.F_weight] = 1.
            else:
                unary = np.zeros([self.args.nbatch, self.args.ncat])
                for idx in range(self.args.ncat):
                    feed_dict[self.objective] = np.tile(
                        np.reshape(np.eye(self.args.ncat)[idx], [1, -1]),
                        [self.args.nbatch, 1])
                    unary[:, idx] = self.sess.run(self.rec_cost_vector,
                                                  feed_dict=feed_dict)
                feed_dict[self.objective] = self.mcf.solve(-unary)[1]
                feed_dict[self.I_weight] = 1.
                feed_dict[self.F_weight] = 1.
        else:
            # no discrete
            feed_dict[self.objective] = np.zeros(
                [self.args.nbatch, self.args.ncat])
            if train_idx < self.args.ntime:
                feed_dict[self.I_weight] = 1.
                feed_dict[self.F_weight] = 1.
            else:
                feed_dict[self.I_weight] = 1.
                feed_dict[self.F_weight] = 1.

        summary, _ = self.sess.run([self.merged, self.train_op],
                                   feed_dict=feed_dict)
        return summary

    def train(self, niter, piter, siter, save_dir=None, asset_dir=None):
        self.logger.info("Model training starts")
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_log_dir = os.path.join(asset_dir, current_time, 'train')
        # test_log_dir = os.path.join(asset_dir, current_time, '/test')
        train_summary_writer = tf.summary.FileWriter(train_log_dir,
                                                     self.sess.graph)

        final_iter = self.start_iter + niter
        max_accuracy = -1
        max_acc_iter = -1

        for iter_ in tqdm_range(self.start_iter, final_iter):
            train_idx = (iter_ - self.start_iter) // piter
            summary = self.run_batch(train_idx)
            train_summary_writer.add_summary(summary, iter_)

            if (iter_ + 1) % siter == 0 or iter_ + 1 == final_iter:
                if self.args.use_discrete:
                    include_discrete = train_idx >= self.args.ntime
                else:
                    include_discrete = False
                accuracy = self.evaluate(include_discrete=include_discrete)

                self.latent_traversal_gif(path=asset_dir +
                                          '{}.gif'.format(iter_ + 1),
                                          include_discrete=include_discrete)
                if max_accuracy == -1 or max_accuracy < accuracy:
                    self.save(iter_, save_dir)
                    self.logger.info("Save process")
                    max_accuracy = accuracy
                    max_acc_iter = iter_
                self.logger.info('max_accuracy: ' + str(max_accuracy))
                self.logger.info('max_acc_iter: ' + str(max_acc_iter))

                with open(os.path.join(asset_dir, 'acc.txt'), 'a') as f:
                    f.write('iter: ' + str(iter_) + ', acc: ' + str(accuracy) +
                            '; max_iter: ' + str(max_acc_iter) +
                            ', max_acc: ' + str(max_accuracy) + '\n')
        self.logger.info("Model training ends")

    def evaluate(self,
                 print_option=False,
                 include_discrete=False,
                 eps=1e-8,
                 nsample=1024):
        if include_discrete:
            total_mean, total_std, latent_cat = self.get_latent_total()
            return DisentanglemetricFactorJointMask(
                mean=total_mean,
                std=total_std,
                latent_cat=latent_cat,
                nclasses=self.dataset.latents_sizes,
                sampler=self.dataset.next_batch_latent_fix_idx,
                print_option=print_option,
                ignore_discrete=False)
        else:
            total_mean, total_std = self.get_mean_std()
            return DisentanglemetricFactorMask(
                mean=total_mean,
                std=total_std,
                nclasses=self.dataset.latents_sizes,
                sampler=self.dataset.next_batch_latent_fix_idx,
                print_option=print_option)

    def get_mean_std(self):
        total_mean, total_std = apply_tf_op_multi_output(
            inputs=self.image,
            session=self.sess,
            input_gate=self.input1,
            output_gate_list=[self.mean_total, self.stddev_total],
            batch_size=self.args.nbatch,
            train_gate=self.istrain)
        return total_mean, total_std

    def get_latent_total(self):
        total_mean, total_std = self.get_mean_std()
        unary = np.zeros([self.ndata, self.args.ncat])
        for idx in range(self.args.ncat):
            unary[:, idx] = apply_tf_op_multi_input(
                inputs_list=[
                    self.image,
                    np.zeros([self.ndata, self.args.nconti]),
                    np.tile(np.reshape(np.eye(self.args.ncat)[idx], [1, -1]),
                            [self.ndata, 1])
                ],
                session=self.sess,
                input_gate_list=[
                    self.input1, self.epsilon_input, self.objective
                ],
                output_gate=self.rec_cost_vector,
                batch_size=self.args.nbatch,
                train_gate=self.istrain)
        latent_cat = np_softmax(-unary)
        return total_mean, total_std, latent_cat

    def latent_traversal_gif(self,
                             path,
                             include_discrete=False,
                             nimage=50,
                             nmin=-1.0,
                             nmax=1.0):
        gif = list()
        for i in range(nimage):
            value = nmin + (nmax - nmin) * i / nimage
            latent_conti = value * np.eye(self.args.nconti)
            if include_discrete:
                latent_cat = np.eye(self.args.ncat)
                gif.append(
                    matrix_image2big_image(
                        np.concatenate([
                            np.expand_dims(
                                self.decode(latent_input=np.concatenate([
                                    latent_conti,
                                    np.tile(
                                        np.expand_dims(latent_cat[j], axis=0),
                                        [self.args.nconti, 1])
                                ],
                                                                        axis=1)
                                            ),
                                axis=0) for j in range(self.args.ncat)
                        ],
                                       axis=0)))
            else:
                latent_cat = np.zeros([self.args.ncat])
                gif.append(
                    matrix_image2big_image(
                        np.expand_dims(
                            self.decode(latent_input=np.concatenate([
                                latent_conti,
                                np.tile(np.expand_dims(latent_cat, axis=0),
                                        [self.args.nconti, 1])
                            ],
                                                                    axis=1)),
                            axis=0)))
        write_gif(content=gif, path=path)

    def generate_image_pairs(self,
                             batch_size,
                             asset_dir,
                             n_pairs,
                             include_discrete=True):
        n_rounds = n_pairs // batch_size
        pairs_path = os.path.join(asset_dir, 'pairs_dataset')
        if not os.path.exists(pairs_path):
            os.makedirs(pairs_path)
        for i in range(n_rounds):
            # z_1 = np.random.normal(size=[batch_size, self.args.nconti])
            # z_2 = np.random.normal(size=[batch_size, self.args.nconti])
            z_1 = np.random.uniform(low=-2,
                                    high=2,
                                    size=[batch_size, self.args.nconti])
            z_2 = np.random.uniform(low=-2,
                                    high=2,
                                    size=[batch_size, self.args.nconti])
            if self.args.latent_type == 'onedim':
                delta_dim = np.random.randint(0,
                                              self.args.nconti,
                                              size=[batch_size])
                delta_onehot = np.zeros((batch_size, self.args.nconti))
                delta_onehot[np.arange(delta_dim.size), delta_dim] = 1
                z_2 = np.where(delta_onehot > 0, z_2, z_1)
            delta_z = z_1 - z_2
            if i == 0:
                labels = delta_z
            else:
                labels = np.concatenate([labels, delta_z], axis=0)

            if include_discrete:
                cat_dim = np.random.randint(0,
                                            self.args.ncat,
                                            size=[batch_size])
                cat_onehot = np.zeros((batch_size, self.args.ncat))
                cat_onehot[np.arange(cat_dim.size), cat_dim] = 1
            else:
                cat_onehot = np.zeros((batch_size, self.args.ncat))
            img_1 = self.sess.run(self.free_dec_output_ph,
                                  feed_dict={
                                      self.free_latent_ph:
                                      np.concatenate([z_1, cat_onehot], axis=1)
                                  })
            img_2 = self.sess.run(self.free_dec_output_ph,
                                  feed_dict={
                                      self.free_latent_ph:
                                      np.concatenate([z_2, cat_onehot], axis=1)
                                  })
            # [b, h, w, c]
            for j in range(img_1.shape[0]):
                pair_np = np.concatenate([img_1[j], img_2[j]], axis=1)
                # pair_np = (pair_np * 255).astype(np.uint8)
                pair_np = (np.squeeze(pair_np) * 255).astype(np.uint8)
                img = Image.fromarray(pair_np)
                img.save(
                    os.path.join(pairs_path,
                                 'pair_%06d.jpg' % (i * batch_size + j)))
            # if i == 0:
            # imgs = np.concatenate([img_1, img_2], axis=2)
            # else:
            # imgs_i = np.concatenate([img_1, img_2], axis=2)
            # imgs = np.concatenate([imgs, imgs_i], axis=0)

        write_npy(labels, os.path.join(pairs_path, 'labels.npy'))
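    # Hedged usage sketch for the pair generator (the argument values below are
    # placeholders, not taken from the source):
    #   model = Model(dataset, logfilepath, args)
    #   model.generate_image_pairs(batch_size=64, asset_dir='./assets', n_pairs=10000)
    # This writes pair_XXXXXX.jpg images and a labels.npy file holding the z_1 - z_2
    # deltas under asset_dir/pairs_dataset.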
Example #5
    def build(self):
        self.logger.info("Model building starts")
        tf.reset_default_graph()
        tf.set_random_seed(self.args.rseed)

        self.input1 = tf.placeholder(
            tf.float32,
            shape=[self.args.nbatch, self.height, self.width, self.nchannel])
        self.epsilon_input = tf.placeholder(
            tf.float32, shape=[self.args.nbatch, self.args.nconti])
        self.objective = tf.placeholder(
            tf.float32, shape=[self.args.nbatch, self.args.ncat])
        self.istrain = tf.placeholder(tf.bool, shape=[])

        self.generate_sess()

        self.mcf = SolveMaxMatching(nworkers=self.args.nbatch,
                                    ntasks=self.args.ncat,
                                    k=1,
                                    pairwise_lamb=self.args.plamb)
        # Encoding
        self.encoder_net = encoder1_64
        self.decoder_net = decoder1_64

        # Continuous rep
        self.mean_total, self.stddev_total = tf.split(self.encoder_net(
            self.input1,
            output_dim=2 * self.args.nconti,
            scope='encoder',
            reuse=False)['output'],
                                                      num_or_size_splits=2,
                                                      axis=1)
        self.stddev_total = tf.nn.softplus(self.stddev_total)
        self.z_sample = tf.add(
            self.mean_total, tf.multiply(self.stddev_total,
                                         self.epsilon_input))
        self.dec_output = self.decoder_net(z=tf.concat(
            [self.z_sample, self.objective], axis=-1),
                                           output_channel=self.nchannel,
                                           scope="decoder",
                                           reuse=False)['output']

        # Unary vector
        self.rec_cost_vector = sigmoid_cross_entropy_without_mean(
            labels=self.input1, logits=self.dec_output)

        # Loss
        self.rec_cost = tf.reduce_mean(self.rec_cost_vector)

        self.loss_dict = dict()
        for idx in range(self.args.nconti + 1):
            weight = tf.constant(
                np.array(idx * [self.args.beta_min] +
                         (self.args.nconti - idx) * [self.args.beta_max]),
                dtype=tf.float32)
            kl_cost = vae_kl_cost_weight(mean=self.mean_total,
                                         stddev=self.stddev_total,
                                         weight=weight)
            self.loss_dict[idx] = (self.rec_cost + kl_cost +
                                   tf.losses.get_regularization_loss())
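        # loss_dict[idx] weights the KL term with beta_min on the first idx continuous
        # dimensions and beta_max on the remaining nconti - idx ones, so the training
        # loop can change the per-dimension KL pressure by choosing idx.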

        # Decode
        self.latent_ph = tf.placeholder(
            tf.float32,
            shape=[self.args.nbatch, self.args.nconti + self.args.ncat])
        self.dec_output_ph = tf.nn.sigmoid(
            self.decoder_net(z=self.latent_ph,
                             output_channel=self.nchannel,
                             scope="decoder",
                             reuse=True)['output'])

        self.logger.info("Model building ends")
Example #6
class Model(ModelPlugin):
    def __init__(self, dataset, logfilepath, args):
        super().__init__(dataset, logfilepath, args)
        self.build()

    def build(self):
        self.logger.info("Model building starts")
        tf.reset_default_graph()
        tf.set_random_seed(self.args.rseed)

        self.input1 = tf.placeholder(
            tf.float32,
            shape=[self.args.nbatch, self.height, self.width, self.nchannel])
        self.epsilon_input = tf.placeholder(
            tf.float32, shape=[self.args.nbatch, self.args.nconti])
        self.objective = tf.placeholder(
            tf.float32, shape=[self.args.nbatch, self.args.ncat])
        self.istrain = tf.placeholder(tf.bool, shape=[])

        self.generate_sess()

        self.mcf = SolveMaxMatching(nworkers=self.args.nbatch,
                                    ntasks=self.args.ncat,
                                    k=1,
                                    pairwise_lamb=self.args.plamb)
        # Encoding
        self.encoder_net = encoder1_64
        self.decoder_net = decoder1_64

        # Continuous rep
        self.mean_total, self.stddev_total = tf.split(self.encoder_net(
            self.input1,
            output_dim=2 * self.args.nconti,
            scope='encoder',
            reuse=False)['output'],
                                                      num_or_size_splits=2,
                                                      axis=1)
        self.stddev_total = tf.nn.softplus(self.stddev_total)
        self.z_sample = tf.add(
            self.mean_total, tf.multiply(self.stddev_total,
                                         self.epsilon_input))
        self.dec_output = self.decoder_net(z=tf.concat(
            [self.z_sample, self.objective], axis=-1),
                                           output_channel=self.nchannel,
                                           scope="decoder",
                                           reuse=False)['output']

        # Unary vector
        self.rec_cost_vector = sigmoid_cross_entropy_without_mean(
            labels=self.input1, logits=self.dec_output)

        # Loss
        self.rec_cost = tf.reduce_mean(self.rec_cost_vector)

        self.loss_dict = dict()
        for idx in range(self.args.nconti + 1):
            weight = tf.constant(
                np.array(idx * [self.args.beta_min] +
                         (self.args.nconti - idx) * [self.args.beta_max]),
                dtype=tf.float32)
            kl_cost = vae_kl_cost_weight(mean=self.mean_total,
                                         stddev=self.stddev_total,
                                         weight=weight)
            self.loss_dict[idx] = (self.rec_cost + kl_cost +
                                   tf.losses.get_regularization_loss())

        # Decode
        self.latent_ph = tf.placeholder(
            tf.float32,
            shape=[self.args.nbatch, self.args.nconti + self.args.ncat])
        self.dec_output_ph = tf.nn.sigmoid(
            self.decoder_net(z=self.latent_ph,
                             output_channel=self.nchannel,
                             scope="decoder",
                             reuse=True)['output'])

        self.logger.info("Model building ends")

    def decode(self, latent_input):
        return apply_tf_op(inputs=latent_input,
                           session=self.sess,
                           input_gate=self.latent_ph,
                           output_gate=self.dec_output_ph,
                           batch_size=self.args.nbatch)

    def set_up_train(self):
        self.logger.info("Model setting up train starts")

        if not hasattr(self, 'start_iter'): self.start_iter = 0
        self.logger.info("Start iter: {}".format(self.start_iter))

        decay_func = DECAY_DICT[self.args.dtype]
        decay_params = DECAY_PARAMS_DICT[self.args.dtype][self.args.nbatch][
            self.args.dptype].copy()
        decay_params['initial_step'] = self.start_iter

        self.lr, update_step_op = decay_func(**decay_params)
        self.update_step_op = [update_step_op]

        var_list = [
            v for v in tf.trainable_variables() if 'encoder' in v.name
        ] + [v for v in tf.trainable_variables() if 'decoder' in v.name]

        self.train_op_dict = dict()
        with tf.control_dependencies(tf.get_collection("update_ops")):
            for idx in range(self.args.nconti + 1):
                self.train_op_dict[idx] = get_train_op_v2(
                    tf.train.AdamOptimizer(learning_rate=self.lr,
                                           beta1=0.9,
                                           beta2=0.999),
                    loss=self.loss_dict[idx],
                    var_list=var_list)

        self.logger.info("Model setting up train ends")

    def run_batch(self, train_idx):
        feed_dict = dict()
        feed_dict[self.input1] = self.dataset.next_batch(
            batch_size=self.args.nbatch)[0]
        feed_dict[self.istrain] = True
        feed_dict[self.epsilon_input] = np.random.normal(
            size=[self.args.nbatch, self.args.nconti])

        if train_idx < self.args.ntime:
            feed_dict[self.objective] = np.zeros(
                [self.args.nbatch, self.args.ncat])
        else:
            unary = np.zeros([self.args.nbatch, self.args.ncat])
            for idx in range(self.args.ncat):
                feed_dict[self.objective] = np.tile(
                    np.reshape(np.eye(self.args.ncat)[idx], [1, -1]),
                    [self.args.nbatch, 1])
                unary[:, idx] = self.sess.run(self.rec_cost_vector,
                                              feed_dict=feed_dict)
            feed_dict[self.objective] = self.mcf.solve(-unary)[1]

        if train_idx >= self.args.ntime:
            idx = min(train_idx, self.args.nconti)
        else:
            idx = min(train_idx + 1, self.args.nconti)
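        # idx selects which KL weighting (and hence which train op) to use: it grows with
        # train_idx and is capped at nconti, so the per-dimension KL weights switch from
        # beta_max to beta_min one dimension at a time as training proceeds.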
        self.sess.run(self.train_op_dict[idx], feed_dict=feed_dict)

    def train(self, niter, piter, siter, save_dir=None, asset_dir=None):
        self.logger.info("Model training starts")

        final_iter = self.start_iter + niter
        max_accuracy = -1

        for iter_ in tqdm_range(self.start_iter, final_iter):
            train_idx = (iter_ - self.start_iter) // piter
            self.run_batch(train_idx)

            if (iter_ + 1) % siter == 0 or iter_ + 1 == final_iter:
                include_discrete = train_idx >= self.args.ntime
                accuracy = self.evaluate(include_discrete=include_discrete)

                self.latent_traversal_gif(path=asset_dir +
                                          '{}.gif'.format(iter_ + 1),
                                          include_discrete=include_discrete)
                if accuracy > max_accuracy:
                    self.save(iter_, save_dir)
                    self.logger.info("Checkpoint saved (best accuracy so far)")
                    max_accuracy = accuracy
        self.logger.info("Model training ends")

    def evaluate(self,
                 print_option=False,
                 include_discrete=False,
                 eps=1e-8,
                 nsample=1024):
        if include_discrete:
            total_mean, total_std, latent_cat = self.get_latent_total()
            return DisentanglemetricFactorJointMask(
                mean=total_mean,
                std=total_std,
                latent_cat=latent_cat,
                nclasses=self.dataset.latents_sizes,
                sampler=self.dataset.next_batch_latent_fix_idx,
                print_option=print_option,
                ignore_discrete=False)
        else:
            total_mean, total_std = self.get_mean_std()
            return DisentanglemetricFactorMask(
                mean=total_mean,
                std=total_std,
                nclasses=self.dataset.latents_sizes,
                sampler=self.dataset.next_batch_latent_fix_idx,
                print_option=print_option)

    def get_mean_std(self):
        total_mean, total_std = apply_tf_op_multi_output(
            inputs=self.image,
            session=self.sess,
            input_gate=self.input1,
            output_gate_list=[self.mean_total, self.stddev_total],
            batch_size=self.args.nbatch,
            train_gate=self.istrain)
        return total_mean, total_std

    def get_latent_total(self):
        total_mean, total_std = self.get_mean_std()
        unary = np.zeros([self.ndata, self.args.ncat])
        for idx in range(self.args.ncat):
            unary[:, idx] = apply_tf_op_multi_input(
                inputs_list=[
                    self.image,
                    np.zeros([self.ndata, self.args.nconti]),
                    np.tile(np.reshape(np.eye(self.args.ncat)[idx], [1, -1]),
                            [self.ndata, 1])
                ],
                session=self.sess,
                input_gate_list=[
                    self.input1, self.epsilon_input, self.objective
                ],
                output_gate=self.rec_cost_vector,
                batch_size=self.args.nbatch,
                train_gate=self.istrain)
        latent_cat = np_softmax(-unary)
        return total_mean, total_std, latent_cat
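
np_softmax is assumed here to be a row-wise softmax helper from the repository's utilities; a numerically stable NumPy version consistent with how it is used above (soft category responsibilities from negative reconstruction costs) might look like:

import numpy as np

def np_softmax_rowwise(x):
    # subtract the row maximum for numerical stability, then normalize per row
    x = x - np.max(x, axis=-1, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=-1, keepdims=True)

# latent_cat = np_softmax_rowwise(-unary)  # each row sums to 1 over the ncat categories
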

    def latent_traversal_gif(self,
                             path,
                             include_discrete=False,
                             nimage=50,
                             nmin=-1.0,
                             nmax=1.0):
        gif = list()
        for i in range(nimage):
            # sweep each continuous dimension from nmin towards nmax
            value = nmin + (nmax - nmin) * i / nimage
            latent_conti = value * np.eye(self.args.nconti)
            if include_discrete:
                latent_cat = np.eye(self.args.ncat)
                frames = list()
                for j in range(self.args.ncat):
                    latent_input = np.concatenate([
                        latent_conti,
                        np.tile(np.expand_dims(latent_cat[j], axis=0),
                                [self.args.nconti, 1])
                    ], axis=1)
                    frames.append(
                        np.expand_dims(self.decode(latent_input=latent_input),
                                       axis=0))
                gif.append(
                    matrix_image2big_image(np.concatenate(frames, axis=0)))
            else:
                latent_cat = np.zeros([self.args.ncat])
                latent_input = np.concatenate([
                    latent_conti,
                    np.tile(np.expand_dims(latent_cat, axis=0),
                            [self.args.nconti, 1])
                ], axis=1)
                gif.append(
                    matrix_image2big_image(
                        np.expand_dims(self.decode(latent_input=latent_input),
                                       axis=0)))
        write_gif(content=gif, path=path)
Example #7
    def build_hash(self):
        self.logger.info("Model building train hash starts")

        self.mcf = SolveMaxMatching(nworkers=self.args.nsclass,
                                    ntasks=self.args.d,
                                    k=1,
                                    pairwise_lamb=self.args.plamb2)

        with slim.arg_scope(
            [slim.fully_connected],
                activation_fn=None,
                weights_regularizer=slim.l2_regularizer(0.0005),
                biases_initializer=tf.zeros_initializer(),
                weights_initializer=tf.truncated_normal_initializer(0.0,
                                                                    0.01)):
            if self.args.ltype == 'triplet':
                # placeholder list
                self.objective_list = [
                    tf.placeholder(dtype=tf.float32,
                                   shape=[self.args.nbatch, self.args.d],
                                   name="objective%d" % i)
                    for i in range(self.args.k)
                ]
                self.embed_k_hash = self.last
                with tf.variable_scope('Hash', reuse=False):
                    self.embed_k_hash = slim.fully_connected(
                        self.embed_k_hash,
                        self.args.d * self.args.k,
                        scope="fc1")  # [batch, d*k]
                self.embed_k_hash_list = tf.split(
                    self.embed_k_hash, num_or_size_splits=self.args.k,
                    axis=1)  # list(k*[batch, d])
                self.embed_k_hash_l2_norm_list = [
                    tf.nn.l2_normalize(v, dim=-1)
                    for v in self.embed_k_hash_list
                ]  # list(k*[batch,d]), each l2 normalize
                self.pairwise_distance = pairwise_distance_w_obj1

                self.loss_hash = 0
                for idx in range(self.args.k):
                    self.loss_hash += triplet_semihard_loss_hash(
                        labels=self.label_list[idx],
                        embeddings=self.embed_k_hash_l2_norm_list[idx],
                        objectives=self.objective_list[idx],
                        pairwise_distance=self.pairwise_distance,
                        margin=self.args.param)
            else:
                self.objective_list = [
                    tf.placeholder(dtype=tf.float32,
                                   shape=[self.args.nbatch // 2, self.args.d],
                                   name="objective%d" % i)
                    for i in range(self.args.k)
                ]
                self.anc_embed_k_hash = self.anc_last
                self.pos_embed_k_hash = self.pos_last
                with tf.variable_scope('Hash', reuse=False):
                    self.anc_embed_k_hash = slim.fully_connected(
                        self.anc_embed_k_hash,
                        self.args.d * self.args.k,
                        scope="fc1")
                with tf.variable_scope('Hash', reuse=True):
                    self.pos_embed_k_hash = slim.fully_connected(
                        self.pos_embed_k_hash,
                        self.args.d * self.args.k,
                        scope="fc1")
                self.anc_embed_k_hash_list = tf.split(
                    self.anc_embed_k_hash,
                    num_or_size_splits=self.args.k,
                    axis=1)  # list(k*[batch, d])
                self.pos_embed_k_hash_list = tf.split(
                    self.pos_embed_k_hash,
                    num_or_size_splits=self.args.k,
                    axis=1)  # list(k*[batch, d])
                self.similarity_func = pairwise_similarity_w_obj1

                self.loss_hash = 0
                for idx in range(self.args.k):
                    self.loss_hash += npairs_loss_hash(
                        labels=self.label_list[idx],
                        embeddings_anchor=self.anc_embed_k_hash_list[idx],
                        embeddings_positive=self.pos_embed_k_hash_list[idx],
                        objective=self.objective_list[idx],
                        similarity_func=self.similarity_func,
                        reg_lambda=self.args.param)

        self.EMBED_K_HASH_LIST = self.anc_embed_k_hash_list if self.args.ltype == 'npair' else self.embed_k_hash_l2_norm_list
        self.tree_idx_set = [
            tf.nn.top_k(v, k=self.args.d)[1] for v in self.EMBED_K_HASH_LIST
        ]  # k*[batch_size, d]
        self.tree_idx = tf.transpose(tf.stack(self.tree_idx_set, axis=0),
                                     [1, 0, 2])  # [batch_size, k, d]

        self.logger.info("Model building train hash ends")
Example #8
class Model(ModelPlugin):
    def __init__(self, train_dataset, val_dataset, test_dataset, logfilepath,
                 args):
        super().__init__(train_dataset, val_dataset, test_dataset, logfilepath,
                         args)

    def build(self):
        self.logger.info("Model building starts")
        tf.reset_default_graph()
        if self.args.ltype == 'npair':
            self.anc_img = tf.placeholder(tf.float32,
                                          shape=[
                                              self.args.nbatch // 2,
                                              self.height, self.width,
                                              self.nchannel
                                          ])
            self.pos_img = tf.placeholder(tf.float32,
                                          shape=[
                                              self.args.nbatch // 2,
                                              self.height, self.width,
                                              self.nchannel
                                          ])
            self.istrain = tf.placeholder(tf.bool, shape=[])
            self.label_list = [
                tf.placeholder(tf.int32, shape=[self.args.nbatch // 2])
                for idx in range(self.args.k)
            ]
        else:  # triplet
            self.img = tf.placeholder(tf.float32,
                                      shape=[
                                          self.args.nbatch, self.height,
                                          self.width, self.nchannel
                                      ])
            self.istrain = tf.placeholder(tf.bool, shape=[])
            self.label_list = [
                tf.placeholder(tf.int32, shape=[self.args.nbatch])
                for idx in range(self.args.k)
            ]
        self.generate_sess()

        self.conv_net = conv1_32
        if self.args.ltype == 'npair':
            self.anc_last = tf.nn.relu(
                self.conv_net(self.anc_img,
                              is_training=self.istrain,
                              reuse=False)[0])
            self.pos_last = tf.nn.relu(
                self.conv_net(self.pos_img,
                              is_training=self.istrain,
                              reuse=True)[0])
        else:  # triplet
            self.last = tf.nn.relu(
                self.conv_net(self.img, is_training=self.istrain,
                              reuse=False)[0])
        self.logger.info("Model building ends")

    def set_info(self, val_arg_sort, te_te_distance, te_tr_distance):
        self.logger.info("Model setting info starts")
        self.val_arg_sort = val_arg_sort
        self.te_te_distance = te_te_distance
        self.te_tr_distance = te_tr_distance
        self.logger.info("Model setting info ends")

    def build_hash(self):
        self.logger.info("Model building train hash starts")

        self.mcf = SolveMaxMatching(nworkers=self.args.nsclass,
                                    ntasks=self.args.d,
                                    k=1,
                                    pairwise_lamb=self.args.plamb2)

        with slim.arg_scope(
            [slim.fully_connected],
                activation_fn=None,
                weights_regularizer=slim.l2_regularizer(0.0005),
                biases_initializer=tf.zeros_initializer(),
                weights_initializer=tf.truncated_normal_initializer(0.0,
                                                                    0.01)):
            if self.args.ltype == 'triplet':
                # placeholder list
                self.objective_list = [
                    tf.placeholder(dtype=tf.float32,
                                   shape=[self.args.nbatch, self.args.d],
                                   name="objective%d" % i)
                    for i in range(self.args.k)
                ]
                self.embed_k_hash = self.last
                with tf.variable_scope('Hash', reuse=False):
                    self.embed_k_hash = slim.fully_connected(
                        self.embed_k_hash,
                        self.args.d * self.args.k,
                        scope="fc1")  # [batch, d*k]
                self.embed_k_hash_list = tf.split(
                    self.embed_k_hash, num_or_size_splits=self.args.k,
                    axis=1)  # list(k*[batch, d])
                self.embed_k_hash_l2_norm_list = [
                    tf.nn.l2_normalize(v, dim=-1)
                    for v in self.embed_k_hash_list
                ]  # list(k*[batch,d]), each l2 normalize
                self.pairwise_distance = pairwise_distance_w_obj1

                self.loss_hash = 0
                for idx in range(self.args.k):
                    self.loss_hash += triplet_semihard_loss_hash(
                        labels=self.label_list[idx],
                        embeddings=self.embed_k_hash_l2_norm_list[idx],
                        objectives=self.objective_list[idx],
                        pairwise_distance=self.pairwise_distance,
                        margin=self.args.param)
            else:
                self.objective_list = [
                    tf.placeholder(dtype=tf.float32,
                                   shape=[self.args.nbatch // 2, self.args.d],
                                   name="objective%d" % i)
                    for i in range(self.args.k)
                ]
                self.anc_embed_k_hash = self.anc_last
                self.pos_embed_k_hash = self.pos_last
                with tf.variable_scope('Hash', reuse=False):
                    self.anc_embed_k_hash = slim.fully_connected(
                        self.anc_embed_k_hash,
                        self.args.d * self.args.k,
                        scope="fc1")
                with tf.variable_scope('Hash', reuse=True):
                    self.pos_embed_k_hash = slim.fully_connected(
                        self.pos_embed_k_hash,
                        self.args.d * self.args.k,
                        scope="fc1")
                self.anc_embed_k_hash_list = tf.split(
                    self.anc_embed_k_hash,
                    num_or_size_splits=self.args.k,
                    axis=1)  # list(k*[batch, d])
                self.pos_embed_k_hash_list = tf.split(
                    self.pos_embed_k_hash,
                    num_or_size_splits=self.args.k,
                    axis=1)  # list(k*[batch, d])
                self.similarity_func = pairwise_similarity_w_obj1

                self.loss_hash = 0
                for idx in range(self.args.k):
                    self.loss_hash += npairs_loss_hash(
                        labels=self.label_list[idx],
                        embeddings_anchor=self.anc_embed_k_hash_list[idx],
                        embeddings_positive=self.pos_embed_k_hash_list[idx],
                        objective=self.objective_list[idx],
                        similarity_func=self.similarity_func,
                        reg_lambda=self.args.param)

        self.EMBED_K_HASH_LIST = self.anc_embed_k_hash_list if self.args.ltype == 'npair' else self.embed_k_hash_l2_norm_list
        self.tree_idx_set = [
            tf.nn.top_k(v, k=self.args.d)[1] for v in self.EMBED_K_HASH_LIST
        ]  # k*[batch_size, d]
        self.tree_idx = tf.transpose(tf.stack(self.tree_idx_set, axis=0),
                                     [1, 0, 2])  # [batch_size, k, d]

        self.logger.info("Model building train hash ends")
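
For intuition, tf.nn.top_k(v, k=self.args.d)[1] returns, per sample and per hash level, the d bucket indices ordered by decreasing score; a NumPy analogue of the resulting tree_idx tensor (tie-breaking aside, with hypothetical sizes) is:

import numpy as np

k, nbatch, d = 3, 8, 4                                                # hypothetical sizes
embed_list = [np.random.normal(size=[nbatch, d]) for _ in range(k)]   # k * [batch, d]
tree_idx_set = [np.argsort(-v, axis=-1) for v in embed_list]          # indices by descending score
tree_idx = np.transpose(np.stack(tree_idx_set, axis=0), [1, 0, 2])    # [batch, k, d]
assert tree_idx.shape == (nbatch, k, d)
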

    def set_up_train_hash(self):
        self.logger.info("Model setting up train hash starts")

        decay_func = DECAY_DICT[self.args.dtype]
        self.lr, update_step_op = decay_func(**DECAY_PARAMS_DICT[
            self.args.dtype][self.args.nbatch][self.args.dptype])

        update_ops = tf.get_collection("update_ops")
        var_slow_list, var_fast_list = list(), list()
        for var in tf.trainable_variables():
            if 'Hash' in var.name: var_fast_list.append(var)
            else: var_slow_list.append(var)

        with tf.control_dependencies(update_ops + [update_step_op]):
            self.train_op_hash = get_multi_train_op(
                tf.train.AdamOptimizer, self.loss_hash,
                [0.1 * self.lr, self.lr], [var_slow_list, var_fast_list])

        self.idx_convert = list()
        tmp = 1
        for i in range(self.args.k):
            self.idx_convert.append(tmp)
            tmp *= self.args.d
        self.idx_convert = np.array(self.idx_convert)[::-1]  # [d**(k-1),...,1]
        self.idx_convert = tf.constant(self.idx_convert,
                                       dtype=tf.int32)  # tensor [k]

        self.max_k_idx = tf.add(
            tf.reduce_sum(tf.multiply(self.tree_idx[:, :-1, 0],
                                      self.idx_convert[:-1]),
                          axis=-1,
                          keep_dims=True),
            self.tree_idx[:, -1, :self.args.sk])  # [batch_size, sk]

        if self.args.ltype == 'npair':
            assert get_shape(
                self.max_k_idx) == [self.args.nbatch // 2,
                                    self.args.sk], "Wrong max_k_idx shape"
        else:
            assert get_shape(
                self.max_k_idx) == [self.args.nbatch,
                                    self.args.sk], "Wrong max_k_idx shape"

        self.logger.info("Model setting up train hash ends")
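
The idx_convert / max_k_idx arithmetic treats the top-1 choices at the first k-1 levels as digits of a base-d number and then adds the sk candidate buckets from the last level. A self-contained NumPy check of that mixed-radix encoding (sizes below are hypothetical):

import numpy as np

k, d, sk, nbatch = 3, 4, 2, 5                                  # hypothetical sizes
tree_idx = np.random.randint(0, d, size=[nbatch, k, d])        # stand-in for the top_k indices
idx_convert = np.array([d ** (k - 1 - i) for i in range(k)])   # [d**(k-1), ..., d, 1]
max_k_idx = ((tree_idx[:, :-1, 0] * idx_convert[:-1]).sum(axis=-1, keepdims=True)
             + tree_idx[:, -1, :sk])                           # [nbatch, sk], bucket ids in [0, d**k)
assert max_k_idx.shape == (nbatch, sk)
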

    def run_batch_hash(self):
        if self.args.ltype == 'npair':
            batch_anc_img, batch_pos_img, batch_anc_label, batch_pos_label = self.dataset_dict[
                'train'].next_batch(batch_size=self.args.nbatch)
            feed_dict = {
                self.anc_img: batch_anc_img,
                self.pos_img: batch_pos_img,
                self.istrain: True
            }

            objective_list, plabel_list = list(), list()

            anc_unary_list, pos_unary_list = self.sess.run(
                [self.anc_embed_k_hash_list, self.pos_embed_k_hash_list],
                feed_dict=feed_dict)  # k*[nbatch//2, d]
            unary_list = [
                0.5 * (anc_unary_list[k_idx] + pos_unary_list[k_idx])
                for k_idx in range(self.args.k)
            ]  # k*[nbatch//2, d]
            unary_list = [
                np.mean(np.reshape(v, [self.args.nsclass, -1, self.args.d]),
                        axis=1) for v in unary_list
            ]  # k*[nsclass, d]

            for k_idx in range(self.args.k):
                unary = unary_list[k_idx]  # [nsclass, d]
                plabel = np.zeros(self.args.nsclass,
                                  dtype=np.int32)  # [nsclass]
                if k_idx != 0:
                    prev_plabel = plabel_list[-1]
                    prev_objective = objective_list[-1]
                    for i in range(self.args.nsclass):
                        plabel[i] = prev_plabel[i] * self.args.d + np.argmax(
                            prev_objective[i])

                objective = np.zeros([self.args.nsclass, self.args.d],
                                     dtype=np.float32)  # [nsclass, d]

                if k_idx == 0: results = self.mcf.solve(unary)
                elif k_idx == self.args.k - 1:
                    results = solve_maxmatching_soft_intraclass_multiselect(
                        array=unary,
                        k=self.args.sk,
                        labels=plabel,
                        plamb1=self.args.plamb1,
                        plamb2=self.args.plamb2)
                else:
                    results = solve_maxmatching_soft_intraclass_multiselect(
                        array=unary,
                        k=1,
                        labels=plabel,
                        plamb1=self.args.plamb1,
                        plamb2=self.args.plamb2)

                for i, j in results:
                    objective[i][j] = 1

                plabel_list.append(plabel)
                objective_list.append(objective)

            objective_list = [
                np.reshape(
                    np.transpose(
                        np.tile(
                            np.transpose(v, [1, 0]),
                            [self.args.nbatch // (2 * self.args.nsclass), 1]),
                        [1, 0]), [self.args.nbatch // 2, self.args.d])
                for v in objective_list
            ]  # k*[nsclass, d] => k*[batch_size//2, d]
            plabel_list = [
                np.reshape(
                    np.tile(np.expand_dims(v, axis=-1),
                            [1, self.args.nbatch // (2 * self.args.nsclass)]),
                    [-1]) for v in plabel_list
            ]  # k*[nsclass] => k*[batch_size//2]
            for k_idx in range(self.args.k):
                feed_dict[self.objective_list[k_idx]] = objective_list[k_idx]
                if k_idx == self.args.k - 1:
                    feed_dict[self.label_list[k_idx]] = bws2label(
                        objective=objective_list[k_idx], sparsity=self.args.sk
                    ) if self.args.label == 'dynamic' else batch_anc_label  # [nbatch//2]
                else:
                    feed_dict[self.label_list[k_idx]] = np.argmax(
                        objective_list[k_idx], axis=1
                    ) if self.args.label == 'dynamic' else batch_anc_label  # [nbatch//2]

            batch_loss_hash = self.sess.run(
                [self.train_op_hash, self.loss_hash], feed_dict=feed_dict)[1]
            return batch_loss_hash
        else:
            batch_img, batch_label = self.dataset_dict['train'].next_batch(
                batch_size=self.args.nbatch)

            feed_dict = {self.img: batch_img, self.istrain: True}
            objective_list, plabel_list = list(), list()

            unary_list = self.sess.run(self.embed_k_hash_l2_norm_list,
                                       feed_dict=feed_dict)  # k*[nbatch, d]
            unary_list = [
                np.mean(np.reshape(v, [self.args.nsclass, -1, self.args.d]),
                        axis=1) for v in unary_list
            ]  # k*[nsclass, d]
            for k_idx in range(self.args.k):
                unary = unary_list[k_idx]  # [nsclass, d]
                plabel = np.zeros(self.args.nsclass,
                                  dtype=np.int32)  # [nsclass]
                if k_idx != 0:
                    prev_plabel = plabel_list[-1]
                    prev_objective = objective_list[-1]
                    for i in range(self.args.nsclass):
                        plabel[i] = prev_plabel[i] * self.args.d + np.argmax(
                            prev_objective[i])

                objective = np.zeros([self.args.nsclass, self.args.d],
                                     dtype=np.float32)  # [nsclass, d]

                if k_idx == 0: results = self.mcf.solve(unary)
                elif k_idx == self.args.k - 1:
                    results = solve_maxmatching_soft_intraclass_multiselect(
                        array=unary,
                        k=self.args.sk,
                        labels=plabel,
                        plamb1=self.args.plamb1,
                        plamb2=self.args.plamb2)
                else:
                    results = solve_maxmatching_soft_intraclass_multiselect(
                        array=unary,
                        k=1,
                        labels=plabel,
                        plamb1=self.args.plamb1,
                        plamb2=self.args.plamb2)

                for i, j in results:
                    objective[i][j] = 1

                plabel_list.append(plabel)
                objective_list.append(objective)

            objective_list = [
                np.reshape(
                    np.transpose(
                        np.tile(np.transpose(v, [1, 0]),
                                [self.args.nbatch // self.args.nsclass, 1]),
                        [1, 0]), [self.args.nbatch, -1])
                for v in objective_list
            ]  # k*[nsclass, d] => k*[batch_size, d]
            plabel_list = [
                np.reshape(
                    np.tile(np.expand_dims(v, axis=-1),
                            [1, self.args.nbatch // self.args.nsclass]), [-1])
                for v in plabel_list
            ]  # k*[nsclass] => k*[batch_size]

            for k_idx in range(self.args.k):
                feed_dict[self.objective_list[k_idx]] = objective_list[k_idx]

                if k_idx == self.args.k - 1:
                    feed_dict[self.label_list[k_idx]] = bws2label(
                        objective=objective_list[k_idx], sparsity=self.args.sk
                    ) if self.args.label == 'dynamic' else batch_label  # [nbatch]
                else:
                    feed_dict[self.label_list[k_idx]] = np.argmax(
                        objective_list[k_idx], axis=1
                    ) if self.args.label == 'dynamic' else batch_label  # [nbatch]

            batch_loss_hash = self.sess.run(
                [self.train_op_hash, self.loss_hash], feed_dict=feed_dict)[1]
            return batch_loss_hash
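
The reshape/transpose/tile combination above only broadcasts each class-level objective row (and each class-level pseudo-label) to all samples of that class in the batch; a quick equivalence check against np.repeat, with hypothetical sizes:

import numpy as np

nsclass, d, repeat = 4, 8, 3                  # repeat = samples per class in the (half-)batch
v = np.random.uniform(size=[nsclass, d])      # class-level objective, [nsclass, d]
tiled = np.reshape(
    np.transpose(np.tile(np.transpose(v, [1, 0]), [repeat, 1]), [1, 0]),
    [nsclass * repeat, d])
assert np.array_equal(tiled, np.repeat(v, repeat, axis=0))
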

    def train_hash(self, epoch, save_dir, board_dir):
        self.logger.info("Model training starts")
        self.writer = SummaryWriter(board_dir)
        self.logger.info("initial_lr : {}".format(self.sess.run(self.lr)))
        if self.args.ltype == 'npair':

            def custom_apply_tf_op(inputs, output_gate):
                return apply_tf_op(inputs=inputs,
                                   session=self.sess,
                                   input_gate=self.anc_img,
                                   output_gate=output_gate,
                                   batch_size=self.args.nbatch // 2,
                                   dim=4,
                                   train_gate=self.istrain)
        else:  # triplet

            def custom_apply_tf_op(inputs, output_gate):
                return apply_tf_op(inputs=inputs,
                                   session=self.sess,
                                   input_gate=self.img,
                                   output_gate=output_gate,
                                   batch_size=self.args.nbatch,
                                   dim=4,
                                   train_gate=self.istrain)

        def eval_on_val():
            val_max_k_idx = custom_apply_tf_op(
                inputs=self.val_image,
                output_gate=self.max_k_idx)  # [nval, sk]
            val_nmi, val_suf = get_nmi_suf_quick(
                index_array=val_max_k_idx,
                label_array=self.val_label,
                ncluster=self.args.d**self.args.k,
                nlabel=self.ncls_val)
            nsuccess = 0
            for i in range(self.nval):
                for j in self.val_arg_sort[i]:
                    if i == j: continue
                    if len(set(val_max_k_idx[j]) & set(val_max_k_idx[i])) > 0:
                        if self.val_label[i] == self.val_label[j]:
                            nsuccess += 1
                        break
            val_p1 = nsuccess / self.nval
            return val_suf, val_nmi, val_p1

        val_suf, val_nmi, val_p1 = eval_on_val()
        max_val_p1 = val_p1
        self.logger.info(
            "Initial val_suf = {} val_nmi = {} val_p1 = {}".format(
                val_suf, val_nmi, val_p1))
        self.save(0, save_dir)
        for epoch_ in range(epoch):
            train_epoch_loss = 0
            for _ in range(self.nbatch_train):
                batch_loss = self.run_batch_hash()
                train_epoch_loss += batch_loss
            train_epoch_loss /= self.nbatch_train
            val_suf, val_nmi, val_p1 = eval_on_val()

            self.logger.info(
                "Epoch({}/{}) train loss = {} val suf = {} val nmi = {} val p1 = {}"
                .format(epoch_ + 1, epoch, train_epoch_loss, val_suf, val_nmi,
                        val_p1))
            self.writer.add_summaries(
                {
                    "loss": train_epoch_loss,
                    "lr": self.sess.run(self.lr),
                    "suf": val_suf,
                    "nmi": val_nmi,
                    "p1": val_p1
                }, epoch_ + 1)
            if max_val_p1 < val_p1:
                max_val_p1 = val_p1
                self.save(epoch_ + 1, save_dir)
        self.logger.info("Model training ends")

    def prepare_test_hash(self):
        self.logger.info("Model preparing test")
        if self.args.ltype == 'npair':

            def custom_apply_tf_op(inputs, output_gate):
                return apply_tf_op(inputs=inputs,
                                   session=self.sess,
                                   input_gate=self.anc_img,
                                   output_gate=output_gate,
                                   batch_size=self.args.nbatch // 2,
                                   dim=4,
                                   train_gate=self.istrain)
        else:  # triplet

            def custom_apply_tf_op(inputs, output_gate):
                return apply_tf_op(inputs=inputs,
                                   session=self.sess,
                                   input_gate=self.img,
                                   output_gate=output_gate,
                                   batch_size=self.args.nbatch,
                                   dim=4,
                                   train_gate=self.istrain)

        self.test_max_k_idx = custom_apply_tf_op(
            inputs=self.test_image, output_gate=self.max_k_idx)  # [ntest, sk]
        self.test_tree_idx = custom_apply_tf_op(
            inputs=self.test_image, output_gate=self.tree_idx)  # [ntest, k, d]

        self.train_max_k_idx = custom_apply_tf_op(
            inputs=self.train_image,
            output_gate=self.max_k_idx)  # [ntrain, sk]
        self.train_tree_idx = custom_apply_tf_op(
            inputs=self.train_image,
            output_gate=self.tree_idx)  # [ntrain, k, d]

    def test_hash_metric(self, k_set):
        self.logger.info("Model testing k hash starts")

        performance = evaluate_hashtree_te_tr_sparsity(
            te_tr_distance=self.te_tr_distance,
            te_te_distance=self.te_te_distance,
            train_tree_idx=self.train_tree_idx,
            test_tree_idx=self.test_tree_idx,
            train_max_k_idx=self.train_max_k_idx,
            test_max_k_idx=self.test_max_k_idx,
            train_label=self.train_label,
            test_label=self.test_label,
            ncls_train=self.ncls_train,
            ncls_test=self.ncls_test,
            k_set=k_set,
            logger=self.logger)

        self.logger.info("Model testing k hash ends")
        return performance
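
A hedged sketch of how the hash-specific methods appear intended to be chained. Dataset wiring, set_info, checkpoint restore, and any session/variable initialization provided by ModelPlugin are omitted here; `model`, the directories, and `k_set` are placeholders supplied by the caller.

def run_hash_pipeline(model, epoch, save_dir, board_dir, k_set):
    model.build()              # conv backbone and input placeholders
    model.build_hash()         # hash heads, hash losses, tree_idx
    model.set_up_train_hash()  # optimizer, learning-rate decay, max_k_idx
    model.train_hash(epoch=epoch, save_dir=save_dir, board_dir=board_dir)
    model.prepare_test_hash()  # cache train/test hash codes
    return model.test_hash_metric(k_set=k_set)
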