def __init__(self, n_uid, n_mid, EMBEDDING_DIM, HIDDEN_SIZE, BATCH_SIZE,
             MEMORY_SIZE, SEQ_LEN=400, Mem_Induction=0, Util_Reg=0,
             use_negsample=False, mask_flag=False):
    """Build the MIMN part of the graph on top of the base model.

    The constructor creates one ``mimn.MIMNCell``, unrolls it over
    ``SEQ_LEN`` steps of ``self.item_his_eb``, reads the final memory with
    ``self.item_eb``, sets ``self.aux_loss`` (only when ``use_negsample``)
    and ``self.reg_loss``, and hands the concatenated features to
    ``self.build_fcn_net``.

    Args:
        n_uid, n_mid, EMBEDDING_DIM, HIDDEN_SIZE, BATCH_SIZE, SEQ_LEN,
        use_negsample: Forwarded unchanged to the base-class constructor.
        MEMORY_SIZE: Passed to the cell as ``memory_size``; also the number
            of per-channel entries that are reset when masking.
        Mem_Induction: Passed to the cell as ``mem_induction``. Values > 0
            make the mask reset also cover the cell's channel RNN
            attributes; the value 1 adds the attention feature over
            ``temp_output_list`` to the final input.
        Util_Reg: Stored as ``self.reg`` and passed to the cell; when truthy
            ``self.reg_loss`` is ``cell.capacity_loss(...)``, else zeros.
        mask_flag: When true, the state is blended back towards the initial
            state after every step according to ``self.mask``.
    """
    super(Model_MIMN, self).__init__(n_uid, n_mid, EMBEDDING_DIM,
                                     HIDDEN_SIZE, BATCH_SIZE, SEQ_LEN,
                                     use_negsample, Flag="MIMN")
    self.reg = Util_Reg

    def clear_mask_state(state, begin_state, begin_channel_rnn_output, mask,
                         cell, t):
        """Blend step ``t``'s state with the initial state using ``mask``.

        For each entry the result is ``(1 - m) * initial + m * current``
        with ``m = mask[:, t]``; assuming mask values are 0/1, rows with
        0 revert to the initial state. ``state`` is updated in place and
        returned; the cell's channel RNN attributes are replaced.
        """
        step_mask = mask[:, t]
        keep_2d = tf.reshape(step_mask, (BATCH_SIZE, 1))
        keep_3d = tf.reshape(step_mask, (BATCH_SIZE, 1, 1))

        state["controller_state"] = (
            (1 - keep_2d) * begin_state["controller_state"]
            + keep_2d * state["controller_state"])
        for key in ("M", "key_M", "sum_aggre"):
            state[key] = ((1 - keep_3d) * begin_state[key]
                          + keep_3d * state[key])

        if Mem_Induction > 0:
            keep = tf.expand_dims(step_mask, axis=1)
            # NOTE(review): the channel RNN *state* is reset to the same
            # snapshot as the channel RNN *output* (the caller only ever
            # captured cell.channel_rnn_output). Presumably both start as
            # identically shaped zero tensors -- confirm in MIMNCell.
            cell.channel_rnn_state = [
                cell.channel_rnn_state[i] * keep
                + begin_channel_rnn_output[i] * (1 - keep)
                for i in range(MEMORY_SIZE)
            ]
            # Use the `mask` argument and the passed-in snapshot rather
            # than `self.mask` / an enclosing-scope variable, so the helper
            # depends only on what it is given.
            cell.channel_rnn_output = [
                cell.channel_rnn_output[i] * keep
                + begin_channel_rnn_output[i] * (1 - keep)
                for i in range(MEMORY_SIZE)
            ]
        return state

    cell = mimn.MIMNCell(controller_units=HIDDEN_SIZE,
                         memory_size=MEMORY_SIZE,
                         memory_vector_dim=2 * EMBEDDING_DIM,
                         read_head_num=1,
                         write_head_num=1,
                         reuse=False,
                         output_dim=HIDDEN_SIZE,
                         clip_value=20,
                         batch_size=BATCH_SIZE,
                         mem_induction=Mem_Induction,
                         util_reg=Util_Reg)

    state = cell.zero_state(BATCH_SIZE, tf.float32)
    # Snapshot of the cell's channel outputs right after zero_state(); only
    # indexed by clear_mask_state when Mem_Induction > 0.
    if Mem_Induction > 0:
        begin_channel_rnn_output = cell.channel_rnn_output
    else:
        begin_channel_rnn_output = 0.0

    begin_state = state
    self.state_list = [state]
    self.mimn_o = []
    for t in range(SEQ_LEN):
        output, state, temp_output_list = cell(self.item_his_eb[:, t, :],
                                               state)
        if mask_flag:
            state = clear_mask_state(state, begin_state,
                                     begin_channel_rnn_output, self.mask,
                                     cell, t)
        self.mimn_o.append(output)
        self.state_list.append(state)

    self.mimn_o = tf.stack(self.mimn_o, axis=1)
    # The final state is appended a second time here (kept as in the
    # original; consumers of state_list may rely on its length).
    self.state_list.append(state)

    mean_memory = tf.reduce_mean(state['sum_aggre'], axis=-2)
    # Captured before the read call below, which returns a new state that
    # is discarded.
    before_aggre = state['w_aggre']
    read_out, _, _ = cell(self.item_eb, state)

    if use_negsample:
        # Step t's output is paired with the history/negative item at t+1.
        aux_loss_1 = self.auxiliary_loss(self.mimn_o[:, :-1, :],
                                         self.item_his_eb[:, 1:, :],
                                         self.neg_his_eb[:, 1:, :],
                                         self.mask[:, 1:],
                                         stag="bigru_0")
        self.aux_loss = aux_loss_1

    if self.reg:
        self.reg_loss = cell.capacity_loss(before_aggre)
    else:
        self.reg_loss = tf.zeros(1)

    if Mem_Induction == 1:
        # temp_output_list comes from the last unrolled step.
        channel_memory_tensor = tf.concat(temp_output_list, 1)
        multi_channel_hist = din_attention(self.item_eb,
                                           channel_memory_tensor,
                                           HIDDEN_SIZE, None, stag='pal')
        inp = tf.concat([
            self.item_eb, self.item_his_eb_sum, read_out,
            tf.squeeze(multi_channel_hist), mean_memory * self.item_eb
        ], 1)
    else:
        inp = tf.concat([
            self.item_eb, self.item_his_eb_sum, read_out,
            mean_memory * self.item_eb
        ], 1)

    self.build_fcn_net(inp, use_dice=False)
def __init__(self, uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n,
             EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, BATCH_SIZE, SEQ_LEN=400,
             Mem_Induction=0, Util_Reg=0, use_negsample=False,
             mask_flag=False, args=None):
    """Build one MIMN sub-graph per configured history window.

    For every ``left:right`` window in ``args.mimn_seq_split`` (falling back
    to ``args.long_seq_split``) a fresh ``mimn.MIMNCell`` is unrolled over
    that slice of ``self.item_his_eb``; its read-out features are appended
    to ``self.inputs``. Afterwards one masked average of the history is
    appended per ``args.long_seq_split`` window, and the concatenation is
    passed to ``self.build_fcn_net``.

    Args:
        uid_n ... brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE,
        BATCH_SIZE, SEQ_LEN, use_negsample, args: Forwarded to the
            base-class constructor.
        Mem_Induction, MEMORY_SIZE: Overridden by ``args.mem_induction`` /
            ``args.memory_size`` when ``args`` is given.
        Util_Reg: Passed to each cell as ``util_reg``. ``self.reg`` is taken
            from ``args.util_reg`` when ``args`` is given, else from this.
        mask_flag: When true, each step's state is blended back towards the
            initial state according to the window's mask.
        args: Option object read for ``util_reg``, ``mem_induction``,
            ``memory_size``, ``mimn_seq_split``, ``long_seq_split``,
            ``mimn_seq_reduce``, ``mimn_update_emb``, ``head_num`` and
            ``att_func``. With ``args=None`` no windows are built and the
            explicit arguments above are used.
    """
    super(Model_MIMN, self).__init__(uid_n, item_n, cate_n, shop_n, node_n,
                                     product_n, brand_n, EMBEDDING_DIM,
                                     HIDDEN_SIZE, MEMORY_SIZE, BATCH_SIZE,
                                     SEQ_LEN, use_negsample, Flag="MIMN",
                                     args=args)
    logging.info(locals())

    # `args` defaults to None, so only dereference it when it was supplied.
    if args is not None:
        self.reg = args.util_reg
        Mem_Induction = args.mem_induction
        MEMORY_SIZE = args.memory_size
        mimn_seq_split = (args.mimn_seq_split
                          if args.mimn_seq_split else args.long_seq_split)
    else:
        self.reg = Util_Reg
        mimn_seq_split = None

    def parse_seq_split(spec):
        """Turn ``"a:b,c:d"`` into ``[(a, b), (c, d)]`` (ints)."""
        windows = []
        for item in spec.split(","):
            bounds = item.split(":")
            windows.append((int(bounds[0]), int(bounds[1])))
        return windows

    def clear_mask_state(state, begin_state, begin_channel_rnn_output, mask,
                         cell, t):
        """Blend step ``t``'s state with the initial state using ``mask``.

        For each entry the result is ``(1 - m) * initial + m * current``
        with ``m = mask[:, t]``; assuming mask values are 0/1, rows with
        0 revert to the initial state. ``state`` is updated in place and
        returned; the cell's channel RNN attributes are replaced.
        """
        step_mask = mask[:, t]
        keep_2d = tf.reshape(step_mask, (BATCH_SIZE, 1))
        keep_3d = tf.reshape(step_mask, (BATCH_SIZE, 1, 1))

        state["controller_state"] = (
            (1 - keep_2d) * begin_state["controller_state"]
            + keep_2d * state["controller_state"])
        for key in ("M", "key_M", "sum_aggre"):
            state[key] = ((1 - keep_3d) * begin_state[key]
                          + keep_3d * state[key])

        if Mem_Induction > 0:
            keep = tf.expand_dims(step_mask, axis=1)
            # NOTE(review): the channel RNN *state* is reset to the same
            # snapshot as the channel RNN *output* (the caller only ever
            # captured cell.channel_rnn_output). Presumably both start as
            # identically shaped zero tensors -- confirm in MIMNCell.
            cell.channel_rnn_state = [
                cell.channel_rnn_state[i] * keep
                + begin_channel_rnn_output[i] * (1 - keep)
                for i in range(MEMORY_SIZE)
            ]
            # Fix: the complement must come from the window-local `mask`
            # passed in (sliced and possibly reduced), not from the global
            # `self.mask`, whose column t is a different time step.
            cell.channel_rnn_output = [
                cell.channel_rnn_output[i] * keep
                + begin_channel_rnn_output[i] * (1 - keep)
                for i in range(MEMORY_SIZE)
            ]
        return state

    inputs = self.inputs

    if args and mimn_seq_split:
        for idx, (left_idx, right_idx) in enumerate(
                parse_seq_split(mimn_seq_split)):
            seq_len = abs(right_idx - left_idx)
            with tf.name_scope('MIMN_Layer_{0}'.format(idx)):
                logging.info("mimn_layer {0}:{1}".format(
                    left_idx, right_idx))

                # Window-local mask and mask-multiplied embeddings.
                mask = self.mask[:, left_idx:right_idx]
                item_his_eb = (self.item_his_eb[:, left_idx:right_idx]
                               * mask[:, :, None])
                if self.use_negsample:
                    neg_his_eb = (self.neg_his_eb[:, left_idx:right_idx]
                                  * mask[:, :, None])
                item_eb = self.item_eb

                # mimn_update_emb == 0: block gradients from this branch
                # into the embeddings.
                if args.mimn_update_emb == 0:
                    item_his_eb = tf.stop_gradient(item_his_eb)
                    item_eb = tf.stop_gradient(item_eb)
                    if self.use_negsample:
                        neg_his_eb = tf.stop_gradient(neg_his_eb)

                memory_vector_dim = (
                    self.item_his_eb.get_shape().as_list()[-1])
                head_num = 1
                if args.head_num:
                    head_num = args.head_num

                # NOTE(review): the cell receives the `Util_Reg` argument
                # while self.reg is taken from args.util_reg -- confirm
                # the two are meant to be able to differ.
                cell = mimn.MIMNCell(controller_units=HIDDEN_SIZE,
                                     memory_size=MEMORY_SIZE,
                                     memory_vector_dim=memory_vector_dim,
                                     read_head_num=head_num,
                                     write_head_num=head_num,
                                     reuse=False,
                                     output_dim=HIDDEN_SIZE,
                                     clip_value=100,
                                     batch_size=BATCH_SIZE,
                                     mem_induction=Mem_Induction,
                                     util_reg=Util_Reg)

                state = cell.zero_state(BATCH_SIZE, tf.float32)
                if Mem_Induction > 0:
                    begin_channel_rnn_output = cell.channel_rnn_output
                else:
                    begin_channel_rnn_output = 0.0

                begin_state = state
                # NOTE(review): these attributes (and aux_loss / reg_loss
                # below) are overwritten on every window, so only the last
                # window's values remain -- confirm this is intended.
                self.state_list = [state]
                self.mimn_o = []

                if args.mimn_seq_reduce:
                    logging.info("mimn_seq_reduce:{0}".format(
                        args.mimn_seq_reduce))
                    seq_reduce = args.mimn_seq_reduce
                    # Average every `seq_reduce` consecutive steps into one.
                    seq_len = seq_len // seq_reduce
                    dim = item_his_eb.get_shape().as_list()[-1]
                    logging.info(dim)
                    item_his_eb = tf.reshape(
                        item_his_eb, [-1, seq_len, seq_reduce, dim])
                    item_his_eb = tf.reduce_mean(item_his_eb, axis=-2)
                    logging.info(item_his_eb.get_shape())
                    # A reduced step is valid if any of its members was.
                    mask = tf.reshape(mask, [-1, seq_len, seq_reduce])
                    mask = tf.reduce_sum(mask, axis=-1)
                    mask = tf.cast(
                        tf.not_equal(mask, tf.zeros_like(mask)),
                        tf.float32)
                    # NOTE(review): neg_his_eb is not reduced here, so its
                    # time axis presumably no longer matches item_his_eb in
                    # the auxiliary loss when use_negsample is on -- confirm.

                for t in range(seq_len):
                    output, state, temp_output_list = cell(
                        item_his_eb[:, t, :], state)
                    if mask_flag:
                        state = clear_mask_state(state, begin_state,
                                                 begin_channel_rnn_output,
                                                 mask, cell, t)
                    self.mimn_o.append(output)
                    self.state_list.append(state)

                self.mimn_o = tf.stack(self.mimn_o, axis=1)
                # Final state is appended a second time (kept as in the
                # original).
                self.state_list.append(state)

                mean_memory = tf.reduce_mean(state['sum_aggre'], axis=-2)
                # Captured before the read call below, whose returned
                # state is discarded.
                before_aggre = state['w_aggre']
                read_out, _, _ = cell(item_eb, state)

                if use_negsample:
                    # Step t's output is paired with the item at t+1.
                    aux_loss_1 = self.auxiliary_loss(
                        self.mimn_o[:, :-1, :],
                        item_his_eb[:, 1:, :],
                        neg_his_eb[:, 1:, :],
                        mask[:, 1:],
                        stag="bigru_0")
                    self.aux_loss = aux_loss_1

                if self.reg:
                    self.reg_loss = cell.capacity_loss(before_aggre)
                else:
                    self.reg_loss = tf.zeros(1)

                if Mem_Induction == 1:
                    # temp_output_list comes from the last unrolled step.
                    channel_memory_tensor = tf.concat(temp_output_list, 1)
                    multi_channel_hist = din_attention(
                        item_eb,
                        channel_memory_tensor,
                        HIDDEN_SIZE,
                        None,
                        stag='pal',
                        att_func=args.att_func)
                    inputs += [
                        read_out,
                        tf.squeeze(multi_channel_hist),
                        mean_memory * item_eb
                    ]
                else:
                    inputs += [read_out, mean_memory * item_eb]

    if args and args.long_seq_split:
        for left_idx, right_idx in parse_seq_split(args.long_seq_split):
            mask = self.mask[:, left_idx:right_idx]
            # Masked sum over the window divided by (valid count + 1.0);
            # the +1.0 keeps the denominator non-zero.
            item_his_sum_emb = tf.reduce_sum(
                self.item_his_eb[:, left_idx:right_idx] * mask[:, :, None],
                1) / (tf.reduce_sum(mask, 1, keepdims=True) + 1.0)
            inputs.append(item_his_sum_emb)

    inp = tf.concat(inputs, 1)
    self.build_fcn_net(inp, use_dice=False)