class Base:
    args = Object()
    need_train = True

    def __init__(self, data: dataset.Data):
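        """Keep a reference to the dataset, pick the checkpoint directory,
        and build the model graph inside the dataset's tf.Graph with a
        fixed random seed."""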
        self.tmp_vars = Object()

        self.data = data

        # self.save_dir = f'{utils.save_dir}/{args.run_name}'
        self.save_dir = f'{utils.save_dir}/{args.msg}'

        with self.data.tf_graph.as_default():
            tf.set_random_seed(args.seed)
            self.compile()
        self.fit_step = 0

    def compile(self):
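        """Create the session, input pipeline and model graph, then
        initialize all variables (and run the embedding L2-normalization op
        if a subclass defines one)."""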
        self.emb_l2_norm_op = None
        self.sess = UTILS.get_session()
        self.make_io()
        self.make_model()
        self.sess.run(tf.global_variables_initializer())
        if self.emb_l2_norm_op is not None:
            self.sess.run(self.emb_l2_norm_op)

    def make_io(self):
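        """Build a feedable-iterator input pipeline: one shared iterator is
        created from a string-handle placeholder, so the training and
        inference datasets can be swapped at run time by feeding their
        handles into ``self.data_handle``."""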
        self.is_on_train = tf.placeholder(tf.bool, [], 'is_on_train')

        train_data = self.data.train_batch_repeat
        train_data_iter = train_data.make_one_shot_iterator()
        self.train_data_handle = self.sess.run(train_data_iter.string_handle())
        self.data_handle = tf.placeholder(tf.string, [], 'data_handle')
        data_iter = tf.data.Iterator.from_string_handle(
            self.data_handle,
            train_data.output_types,
            train_data.output_shapes,
        )
        self.input_dict = data_iter.get_next()
        self.input_dict = Object(**self.input_dict)

    def get_metric_v(self, x, predict_v):
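        """Per-example evaluation tensors for one batch.

        ``m`` is 1 if the true item appears in ``top_items`` and 0 otherwise,
        so ``hit_rate`` is a per-example hit indicator, ``rank = idx + m - 1``
        is the 0-based position of the hit (-1 on a miss), and ``ndcg``
        equals log(2) / log(2 + rank) on a hit and 0 on a miss."""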
        # [BS,]
        true_item = x.ans
        true_item_a1 = tf.expand_dims(true_item, -1)
        # [BS, M], [BS, 1]
        eq = tf.cast(tf.equal(predict_v.top_items, true_item_a1), tf.int32)
        # [BS,]
        m = tf.reduce_max(eq, -1)
        idx = tf.cast(tf.argmax(eq, -1), tf.int32)
        rank = idx + m - 1
        ndcg = tf.log(2.0) * tf.cast(m, tf.float32) / tf.log(2.0 + tf.cast(idx, tf.float32))
        hit_rate = tf.cast(m, tf.float32)
        ret = Object(
            ndcg=ndcg,
            hit_rate=hit_rate,
            user=x.user,
            true_item=true_item,
            phase=x.phase,
            top_items=predict_v.top_items,
            top_scores=predict_v.top_scores,
            rank=rank,
            q_ts=x.q_ts,
        )
        return ret

    def get_predict_v(self, x, score):
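        """Turn raw item scores into top-k recommendations.

        Depending on ``args.mode_pop`` / ``args.mode_rare`` the scores are
        re-weighted by item popularity and rare items are boosted, which
        appears to be a popularity de-biasing heuristic; invalid items are
        then masked out and the top-k items are selected."""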
        # [N]
        item_pop = tf.cast(self.data.item_pop, tf.float32)
        item_pop_log = tf.log(item_pop + np.e)

        item_deg_self = tf.cast(tf.gather(self.data.item_deg_self_per_phase, x.phase), tf.float32)
        item_pop_self_log = tf.log(item_deg_self + np.e)

        if args.mode_pop == 'log':
            score = score * args.alpha_pop_base + score / item_pop_log
        elif args.mode_pop == 'log_mdeg':
            score = score * item_pop_self_log / item_pop_log
        elif args.mode_pop == 'log_mdeg_only':
            score = score * item_pop_self_log
        elif args.mode_pop == 'linear':
            item_pop = tf.cast(self.data.item_pop, tf.float32) + 1.0
            score = score * args.alpha_pop_base + score / item_pop
        elif args.mode_pop == 'log_md':
            item_pop_self_log = tf.log(item_deg_self + np.e + 10.0)
            score = score * item_pop_self_log / item_pop_log

        if args.mode_rare in {'log', 'linear', 'base'}:
            if args.mode_rare == 'log':
                rare_weight_pop = 1.0 / tf.log(tf.cast(self.data.item_pop, tf.float32) + np.e)
            elif args.mode_rare == 'linear':
                rare_weight_pop = 1.0 / tf.cast(self.data.item_pop, tf.float32)
            elif args.mode_rare == 'base':
                rare_weight_pop = 0.0
            else:
                raise ValueError(f'unknown args.mode_rare: {args.mode_rare}')
            # [N]
            rare_weight = float(args.alpha_rare)
            rare_weight = rare_weight + rare_weight_pop
            rare_weight *= float(args.alpha_rare_mul)

            is_rare = self.is_rare(x)
            score = tf.where(is_rare, score * rare_weight + float(args.alpha_rare_base), score)

        score = UTILS.mask_logits(score, x.score_mask)

        tf.summary.histogram('score', score)
        top_items, top_scores = self.topk_idx(score, x)

        if args.dump_all:
            self.tmp_vars.update(all_scores=score, item_seq=x.seq, ts_seq=x.ts, q_ts=x.q_ts)
            ret = Object(user=x.user, phase=x.phase, top_items=top_items, top_scores=top_scores)
            ret.update(**self.tmp_vars)
            return ret

        return Object(user=x.user, phase=x.phase, top_items=top_items, top_scores=top_scores)

    def make_model(self):
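        """Build the network under the ``Network`` variable scope (with L2
        regularization attached), derive metric tensors from the
        predictions, and create a Saver over the trainable variables."""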
        with tf.variable_scope('Network', reuse=tf.AUTO_REUSE, regularizer=UTILS.l2_loss('all')):
            x = self.input_dict
            self.train_op, self.train_v, self.predict_v = self.forward(x)
            self.metric_v = self.get_metric_v(x, self.predict_v)
            self.metric_v.update(loss=self.train_v.loss)

        network_var_list = tf.trainable_variables(scope='^Network/')
        if network_var_list:
            args.log.log('trainable_variables:')
            for v in network_var_list:
                args.log.log(f'network: {v}')
                tf.summary.histogram(v.op.name, v)  # op name avoids the ':' in tensor names, which summary tags sanitize
            self.saver = tf.train.Saver(var_list=tf.trainable_variables())
            # self.saver_emb = tf.train.Saver(var_list=tf.trainable_variables(scope='^Network/Emb_'))

    def fit(self):
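        """Run a single training step on the training iterator, optionally
        writing TensorBoard summaries and re-normalizing embeddings, and
        return the fetched training tensors."""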
        data = {
            self.is_on_train: True,
            self.data_handle: self.train_data_handle,
        }
        tb_v = []
        if args.run_tb:
            tb_v = [self.all_summary]
        debug_v = DEBUG.fit_show_list
        all_v = [self.train_op, self.train_v, debug_v, tb_v]
        _, train_v, debug_v, tb_v = self.sess.run(all_v, data)
        if self.emb_l2_norm_op is not None:
            self.sess.run(self.emb_l2_norm_op)
        if tb_v:
            self.tbfw.add_summary(tb_v[0], self.fit_step)
        DEBUG.when_run(debug_v)
        self.fit_step += 1
        return train_v

    def inference(self, data, out_obj):
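        """Run ``out_obj`` over ``data`` one batch at a time, yielding each
        batch's fetched values until the one-shot iterator is exhausted."""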
        with self.data.tf_graph.as_default():
            data_iter = data.make_one_shot_iterator()
            data_handle = self.sess.run(data_iter.string_handle())
            data = {
                self.is_on_train: False,
                self.data_handle: data_handle,
            }
        while True:
            try:
                ret_value, debug_v = self.sess.run([out_obj, DEBUG.inf_show_list], data)
                DEBUG.when_run(debug_v)
                yield ret_value
            except tf.errors.OutOfRangeError:
                break

    def metric(self, data):
        for v in self.inference(data, self.metric_v):
            yield v

    def predict(self, data):
        for v in self.inference(data, self.predict_v):
            yield v

    def save(self, s):
        if not self.need_train:
            return
        name = f'{self.save_dir}/model_{s}.ckpt'
        self.saver.save(self.sess, name)

    def restore(self, s):
        if not self.need_train:
            return
        name = f'{self.save_dir}/model_{s}.ckpt'
        self.saver.restore(self.sess, name)

    def restore_from_other(self, run_name):
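        """Restore the step-0 checkpoint saved under another run name,
        first pulling that run's directory from HDFS if it is missing
        locally."""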
        save_dir = f'{utils.save_dir}/{run_name}'
        s = 0
        if not self.need_train:
            return
        import os
        if not os.path.isdir(save_dir):
            args.log.log('download from hdfs')
            sh = f'$HADOOP_HOME/bin/hadoop fs -get save/{utils.project_name}/{run_name} {utils.save_dir}/'
            print(os.system(sh))
        name = f'{save_dir}/model_{s}.ckpt'
        self.saver.restore(self.sess, name)
        # if args.restore_train:
        #     self.saver_emb.restore(self.sess, name)
        # else:
        #     self.saver.restore(self.sess, name)


    def forward(self, x):
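        """Build the model for one input batch; implemented by subclasses.

        Based on how ``make_model`` uses it, this is expected to return a
        ``(train_op, train_v, predict_v)`` triple where ``train_v`` carries a
        ``loss`` tensor and ``predict_v`` is typically produced by
        ``self.get_predict_v(x, score)`` (it must expose ``top_items`` and
        ``top_scores`` for ``get_metric_v``)."""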
        raise NotImplementedError

    def is_rare(self, x):
        is_rare = tf.gather(self.data.is_rare_per_phase, x.phase)
        return is_rare

    def topk_idx(self, prob, x):
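        """Top-k selection with optional reserved slots for rare items.

        If ``args.nb_rare_k`` >= 0, ``nb_rare_k`` candidates are taken from
        rare items and ``nb_topk - nb_rare_k`` from non-rare items, and the
        union is re-sorted by score; otherwise a plain top-k is returned."""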
        rare_k = args.nb_rare_k
        if rare_k < 0:
            topk = tf.nn.top_k(prob, args.nb_topk)
            return topk.indices, topk.values

        is_rare = self.is_rare(x)
        prob_rare = UTILS.mask_logits(prob, is_rare)
        prob_rich = UTILS.mask_logits(prob, tf.logical_not(is_rare))

        topk_rare = tf.nn.top_k(prob_rare, rare_k).indices
        topk_rich = tf.nn.top_k(prob_rich, args.nb_topk - rare_k).indices
        topk = tf.concat([topk_rich, topk_rare], -1)
        # [BS, N], [BS, L] --> [BS, L]
        top_prob = tf.batch_gather(prob, topk)
        sort_topk = tf.nn.top_k(top_prob, args.nb_topk)
        sort_idx = sort_topk.indices
        top_values = sort_topk.values

        sorted_topk = tf.batch_gather(topk, sort_idx)
        return sorted_topk, top_values

    def before_train(self):
        pass

    def after_train(self):
        pass