import os

import numpy as np
import tensorflow as tf

import dataset
import utils
# `UTILS`, `DEBUG`, `Object`, and `args` are additional project-local helpers;
# their exact import locations are repository-specific and assumed in scope.


class Base:
    args = Object()
    need_train = True

    def __init__(self, data: dataset.Data):
        self.tmp_vars = Object()
        self.data = data
        # self.save_dir = f'{utils.save_dir}/{args.run_name}'
        self.save_dir = f'{utils.save_dir}/{args.msg}'
        with self.data.tf_graph.as_default():
            tf.set_random_seed(args.seed)
            self.compile()
        self.fit_step = 0

    def compile(self):
        # `emb_l2_norm_op` may be set by a subclass inside forward().
        self.emb_l2_norm_op = None
        self.sess = UTILS.get_session()
        self.make_io()
        self.make_model()
        self.sess.run(tf.global_variables_initializer())
        if self.emb_l2_norm_op is not None:
            self.sess.run(self.emb_l2_norm_op)

    def make_io(self):
        # Feedable-iterator plumbing: a persistent training iterator plus a
        # string-handle placeholder so evaluation datasets can be swapped in.
        self.is_on_train = tf.placeholder(tf.bool, [], 'is_on_train')
        train_data = self.data.train_batch_repeat
        train_data_iter = train_data.make_one_shot_iterator()
        self.train_data_handle = self.sess.run(train_data_iter.string_handle())
        self.data_handle = tf.placeholder(tf.string, [], 'data_handle')
        data_iter = tf.data.Iterator.from_string_handle(
            self.data_handle,
            train_data.output_types,
            train_data.output_shapes,
        )
        self.input_dict = data_iter.get_next()
        self.input_dict = Object(**self.input_dict)

    def get_metric_v(self, x, predict_v):
        # [BS]
        true_item = x.ans
        # [BS, 1]
        true_item_a1 = tf.expand_dims(true_item, -1)
        # [BS, M] vs [BS, 1] -> [BS, M]
        eq = tf.cast(tf.equal(predict_v.top_items, true_item_a1), tf.int32)
        # [BS]: 1 if the true item appears in the top-M, else 0
        m = tf.reduce_max(eq, -1)
        idx = tf.cast(tf.argmax(eq, -1), tf.int32)
        # 0-based rank of the hit; -1 on a miss
        rank = idx + m - 1
        # Single-relevant-item NDCG: 1 / log2(rank + 2)
        ndcg = tf.log(2.0) * tf.cast(m, tf.float32) / tf.log(2.0 + tf.cast(idx, tf.float32))
        hit_rate = tf.cast(m, tf.float32)
        ret = Object(
            ndcg=ndcg,
            hit_rate=hit_rate,
            user=x.user,
            true_item=true_item,
            phase=x.phase,
            top_items=predict_v.top_items,
            top_scores=predict_v.top_scores,
            rank=rank,
            q_ts=x.q_ts,
        )
        return ret

    def get_predict_v(self, x, score):
        # [N]
        item_pop = tf.cast(self.data.item_pop, tf.float32)
        item_pop_log = tf.log(item_pop + np.e)
        item_deg_self = tf.cast(tf.gather(self.data.item_deg_self_per_phase, x.phase), tf.float32)
        item_pop_self_log = tf.log(item_deg_self + np.e)
        # Popularity debiasing: rescale scores by (inverse) item popularity.
        if args.mode_pop == 'log':
            score = score * args.alpha_pop_base + score / item_pop_log
        elif args.mode_pop == 'log_mdeg':
            score = score * item_pop_self_log / item_pop_log
        elif args.mode_pop == 'log_mdeg_only':
            score = score * item_pop_self_log
        elif args.mode_pop == 'linear':
            item_pop = tf.cast(self.data.item_pop, tf.float32) + 1.0
            score = score * args.alpha_pop_base + score / item_pop
        elif args.mode_pop == 'log_md':
            item_pop_self_log = tf.log(item_deg_self + np.e + 10.0)
            score = score * item_pop_self_log / item_pop_log

        # Optional extra boost for rare items.
        if args.mode_rare in {'log', 'linear', 'base'}:
            if args.mode_rare == 'log':
                rare_weight_pop = 1.0 / tf.log(tf.cast(self.data.item_pop, tf.float32) + np.e)
            elif args.mode_rare == 'linear':
                rare_weight_pop = 1.0 / tf.cast(self.data.item_pop, tf.float32)
            elif args.mode_rare == 'base':
                rare_weight_pop = 0.0
            else:
                raise ValueError(f'unknown mode_rare: {args.mode_rare}')
            # [N]
            rare_weight = float(args.alpha_rare) + rare_weight_pop
            rare_weight *= float(args.alpha_rare_mul)
            is_rare = self.is_rare(x)
            score = tf.where(is_rare, score * rare_weight + float(args.alpha_rare_base), score)

        score = UTILS.mask_logits(score, x.score_mask)
        tf.summary.histogram('score', score)
        top_items, top_scores = self.topk_idx(score, x)
        if args.dump_all:
            self.tmp_vars.update(all_scores=score, item_seq=x.seq, ts_seq=x.ts, q_ts=x.q_ts)
            ret = Object(user=x.user, phase=x.phase, top_items=top_items, top_scores=top_scores)
            ret.update(**self.tmp_vars)
            return ret
        return Object(user=x.user, phase=x.phase, top_items=top_items, top_scores=top_scores)

    def make_model(self):
        with tf.variable_scope('Network', reuse=tf.AUTO_REUSE, regularizer=UTILS.l2_loss('all')):
            x = self.input_dict
            self.train_op, self.train_v, self.predict_v = self.forward(x)
            self.metric_v = self.get_metric_v(x, self.predict_v)
            self.metric_v.update(loss=self.train_v.loss)
        network_var_list = tf.trainable_variables(scope='^Network/')
        if network_var_list:
            args.log.log('trainable_variables:')
            for v in network_var_list:
                args.log.log(f'network: {v}')
                tf.summary.histogram(v.name, v)
        self.saver = tf.train.Saver(var_list=tf.trainable_variables())
        # self.saver_emb = tf.train.Saver(var_list=tf.trainable_variables(scope='^Network/Emb_'))

    def fit(self):
        data = {
            self.is_on_train: True,
            self.data_handle: self.train_data_handle,
        }
        tb_v = []
        if args.run_tb:
            # `all_summary` / `tbfw` are created elsewhere (not shown in this file).
            tb_v = [self.all_summary]
        debug_v = DEBUG.fit_show_list
        all_v = [self.train_op, self.train_v, debug_v, tb_v]
        _, train_v, debug_v, tb_v = self.sess.run(all_v, data)
        if self.emb_l2_norm_op is not None:
            self.sess.run(self.emb_l2_norm_op)
        if tb_v:
            self.tbfw.add_summary(tb_v[0], self.fit_step)
        DEBUG.when_run(debug_v)
        self.fit_step += 1
        return train_v

    def inference(self, data, out_obj):
        with self.data.tf_graph.as_default():
            data_iter = data.make_one_shot_iterator()
            data_handle = self.sess.run(data_iter.string_handle())
            data = {
                self.is_on_train: False,
                self.data_handle: data_handle,
            }
            while True:
                try:
                    ret_value, debug_v = self.sess.run([out_obj, DEBUG.inf_show_list], data)
                    DEBUG.when_run(debug_v)
                    yield ret_value
                except tf.errors.OutOfRangeError:
                    break

    def metric(self, data):
        for v in self.inference(data, self.metric_v):
            yield v

    def predict(self, data):
        for v in self.inference(data, self.predict_v):
            yield v

    def save(self, s):
        if not self.need_train:
            return
        name = f'{self.save_dir}/model_{s}.ckpt'
        self.saver.save(self.sess, name)

    def restore(self, s):
        if not self.need_train:
            return
        name = f'{self.save_dir}/model_{s}.ckpt'
        self.saver.restore(self.sess, name)

    def restore_from_other(self, run_name):
        save_dir = f'{utils.save_dir}/{run_name}'
        s = 0
        if not self.need_train:
            return
        if not os.path.isdir(save_dir):
            # Fetch the checkpoint from HDFS when it is missing locally.
            args.log.log('download from hdfs')
            sh = f'$HADOOP_HOME/bin/hadoop fs -get save/{utils.project_name}/{run_name} {utils.save_dir}/'
            print(os.system(sh))
        name = f'{save_dir}/model_{s}.ckpt'
        self.saver.restore(self.sess, name)
        # if args.restore_train:
        #     self.saver_emb.restore(self.sess, name)
        # else:
        #     self.saver.restore(self.sess, name)

    def forward(self, x):
        raise NotImplementedError

    def is_rare(self, x):
        is_rare = tf.gather(self.data.is_rare_per_phase, x.phase)
        return is_rare

    def topk_idx(self, prob, x):
        # Top-k with an optional quota: `nb_rare_k` of the `nb_topk` slots are
        # reserved for rare items; a negative quota disables the split.
        rare_k = args.nb_rare_k
        if rare_k < 0:
            topk = tf.nn.top_k(prob, args.nb_topk)
            return topk.indices, topk.values
        is_rare = self.is_rare(x)
        prob_rare = UTILS.mask_logits(prob, is_rare)
        prob_rich = UTILS.mask_logits(prob, tf.logical_not(is_rare))
        topk_rare = tf.nn.top_k(prob_rare, rare_k).indices
        topk_rich = tf.nn.top_k(prob_rich, args.nb_topk - rare_k).indices
        topk = tf.concat([topk_rich, topk_rare], -1)
        # [BS, N], [BS, L] --> [BS, L]: re-sort the merged candidates by raw score.
        top_prob = tf.batch_gather(prob, topk)
        sort_topk = tf.nn.top_k(top_prob, args.nb_topk)
        sorted_topk = tf.batch_gather(topk, sort_topk.indices)
        return sorted_topk, sort_topk.values

    def before_train(self):
        pass

    def after_train(self):
        pass
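

# ---------------------------------------------------------------------------
# Reference sketch (an assumption, not part of the original training code):
# a minimal NumPy re-implementation of the rare-item quota in `Base.topk_idx`,
# handy for checking the behaviour outside a TF graph. `rare_k` of the final
# `k` slots are filled from rare items, the rest from non-rare items, and the
# union is re-sorted by raw score. Masking with the dtype minimum stands in
# for `UTILS.mask_logits`; the function name and toy inputs are illustrative.
def _topk_with_rare_quota(scores, is_rare, k, rare_k):
    neg_inf = np.finfo(scores.dtype).min
    top_rare = np.argsort(-np.where(is_rare, scores, neg_inf))[:rare_k]
    top_rich = np.argsort(-np.where(is_rare, neg_inf, scores))[:k - rare_k]
    merged = np.concatenate([top_rich, top_rare])
    order = np.argsort(-scores[merged])
    return merged[order], scores[merged][order]


if __name__ == '__main__':
    s = np.array([0.9, 0.1, 0.8, 0.05, 0.7], dtype=np.float32)
    rare = np.array([False, True, False, True, False])
    items, vals = _topk_with_rare_quota(s, rare, k=4, rare_k=2)
    print(items, vals)  # -> [0 2 1 3] [0.9 0.8 0.1 0.05]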