Ejemplo n.º 1
0
    def __init__(self,
                 n_uid,
                 n_mid,
                 EMBEDDING_DIM,
                 HIDDEN_SIZE,
                 BATCH_SIZE,
                 MEMORY_SIZE,
                 SEQ_LEN=400,
                 Mem_Induction=0,
                 Util_Reg=0,
                 use_negsample=False,
                 mask_flag=False):
        """Build the MIMN graph on top of the base model's embeddings.

        A ``mimn.MIMNCell`` is stepped over ``self.item_his_eb`` for
        ``SEQ_LEN`` steps and then once more with ``self.item_eb``. Its
        output, the mean over axis -2 of the final ``sum_aggre`` state
        (multiplied by the target embedding) and, when ``Mem_Induction == 1``,
        a ``din_attention`` read over the channel outputs are concatenated
        with ``self.item_eb`` / ``self.item_his_eb_sum`` and passed to
        ``self.build_fcn_net``.

        Args:
            n_uid, n_mid, use_negsample: forwarded to the base-class
                constructor (``use_negsample`` also gates ``self.aux_loss``).
            EMBEDDING_DIM: forwarded to the base class; memory vectors are
                ``2 * EMBEDDING_DIM`` wide.
            HIDDEN_SIZE: controller/output width of the cell; also passed to
                ``din_attention``.
            BATCH_SIZE: static batch size; mask columns are reshaped to it.
            MEMORY_SIZE: number of memory slots (and of channel-RNN
                states/outputs when ``Mem_Induction > 0``).
            SEQ_LEN: number of history steps unrolled (also forwarded).
            Mem_Induction: > 0 enables channel-RNN bookkeeping in the mask
                reset; == 1 also adds the channel attention feature.
            Util_Reg: stored as ``self.reg``; when truthy ``self.reg_loss`` is
                ``cell.capacity_loss(...)``, otherwise ``tf.zeros(1)``.
            mask_flag: when True, every step's state is blended with the
                initial state by ``self.mask[:, t]`` (a 0 mask resets it).
        """
        super(Model_MIMN, self).__init__(n_uid,
                                         n_mid,
                                         EMBEDDING_DIM,
                                         HIDDEN_SIZE,
                                         BATCH_SIZE,
                                         SEQ_LEN,
                                         use_negsample,
                                         Flag="MIMN")
        self.reg = Util_Reg

        def blend(new, old, keep):
            """Return ``keep * new + (1 - keep) * old`` (mask-weighted select)."""
            return keep * new + (1 - keep) * old

        def clear_mask_state(state, begin_state, begin_channel_rnn_state,
                             begin_channel_rnn_output, mask, cell, t):
            """Blend step ``t``'s state with the initial state by ``mask[:, t]``.

            Where the mask is 1 the new state is kept; where it is 0 the
            initial value is restored. ``state`` is updated in place and
            returned. When ``Mem_Induction > 0`` the cell's channel-RNN
            states and outputs are blended with their *own* initial values.
            """
            keep_2d = tf.reshape(mask[:, t], (BATCH_SIZE, 1))
            keep_3d = tf.reshape(mask[:, t], (BATCH_SIZE, 1, 1))
            state["controller_state"] = blend(state["controller_state"],
                                              begin_state["controller_state"],
                                              keep_2d)
            for key in ("M", "key_M", "sum_aggre"):
                state[key] = blend(state[key], begin_state[key], keep_3d)
            if Mem_Induction > 0:
                # Use the `mask` argument for both coefficients so the blend
                # stays a convex combination of one mask.
                keep = tf.expand_dims(mask[:, t], axis=1)
                cell.channel_rnn_state = [
                    blend(cell.channel_rnn_state[i],
                          begin_channel_rnn_state[i], keep)
                    for i in range(MEMORY_SIZE)
                ]
                cell.channel_rnn_output = [
                    blend(cell.channel_rnn_output[i],
                          begin_channel_rnn_output[i], keep)
                    for i in range(MEMORY_SIZE)
                ]
            return state

        cell = mimn.MIMNCell(controller_units=HIDDEN_SIZE,
                             memory_size=MEMORY_SIZE,
                             memory_vector_dim=2 * EMBEDDING_DIM,
                             read_head_num=1,
                             write_head_num=1,
                             reuse=False,
                             output_dim=HIDDEN_SIZE,
                             clip_value=20,
                             batch_size=BATCH_SIZE,
                             mem_induction=Mem_Induction,
                             util_reg=Util_Reg)

        state = cell.zero_state(BATCH_SIZE, tf.float32)
        if Mem_Induction > 0:
            # Snapshot the initial channel-RNN state and output separately so
            # a masked step restores each from its own initial value.
            begin_channel_rnn_state = cell.channel_rnn_state
            begin_channel_rnn_output = cell.channel_rnn_output
        else:
            begin_channel_rnn_state = begin_channel_rnn_output = 0.0

        begin_state = state
        self.state_list = [state]
        self.mimn_o = []
        for t in range(SEQ_LEN):
            output, state, temp_output_list = cell(self.item_his_eb[:, t, :],
                                                   state)
            if mask_flag:
                state = clear_mask_state(state, begin_state,
                                         begin_channel_rnn_state,
                                         begin_channel_rnn_output,
                                         self.mask, cell, t)
            self.mimn_o.append(output)
            self.state_list.append(state)

        self.mimn_o = tf.stack(self.mimn_o, axis=1)
        # NOTE(review): the final state is appended a second time (it is
        # already the last element); kept for existing consumers.
        self.state_list.append(state)
        mean_memory = tf.reduce_mean(state['sum_aggre'], axis=-2)

        before_aggre = state['w_aggre']
        # One extra cell step driven by the target item; only its output is
        # used (as the read-out feature).
        read_out, _, _ = cell(self.item_eb, state)

        if use_negsample:
            self.aux_loss = self.auxiliary_loss(self.mimn_o[:, :-1, :],
                                                self.item_his_eb[:, 1:, :],
                                                self.neg_his_eb[:, 1:, :],
                                                self.mask[:, 1:],
                                                stag="bigru_0")

        if self.reg:
            self.reg_loss = cell.capacity_loss(before_aggre)
        else:
            self.reg_loss = tf.zeros(1)

        features = [self.item_eb, self.item_his_eb_sum, read_out]
        if Mem_Induction == 1:
            # `temp_output_list` comes from the last history step of the loop.
            channel_memory_tensor = tf.concat(temp_output_list, 1)
            multi_channel_hist = din_attention(self.item_eb,
                                               channel_memory_tensor,
                                               HIDDEN_SIZE,
                                               None,
                                               stag='pal')
            # NOTE(review): squeeze() without `axis` also drops any other
            # size-1 dimension (e.g. BATCH_SIZE == 1); confirm din_attention's
            # output shape and pin the axis.
            features.append(tf.squeeze(multi_channel_hist))
        features.append(mean_memory * self.item_eb)
        inp = tf.concat(features, 1)

        self.build_fcn_net(inp, use_dice=False)
Ejemplo n.º 2
0
    def __init__(self,
                 uid_n,
                 item_n,
                 cate_n,
                 shop_n,
                 node_n,
                 product_n,
                 brand_n,
                 EMBEDDING_DIM,
                 HIDDEN_SIZE,
                 MEMORY_SIZE,
                 BATCH_SIZE,
                 SEQ_LEN=400,
                 Mem_Induction=0,
                 Util_Reg=0,
                 use_negsample=False,
                 mask_flag=False,
                 args=None):
        """Build one MIMN layer per history split, then the final FCN.

        Split ranges come from ``args.mimn_seq_split`` (falling back to
        ``args.long_seq_split``), formatted ``"left:right,left:right,..."``.
        For each range a fresh ``mimn.MIMNCell`` is stepped over the masked
        (optionally mean-pooled) history embeddings, and its read-out and
        memory features are appended to ``self.inputs``. When
        ``args.long_seq_split`` is set, each of its splits also contributes a
        masked sum of history embeddings divided by (valid steps + 1). The
        result is concatenated and passed to ``self.build_fcn_net``.

        Args:
            uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n,
            EMBEDDING_DIM, SEQ_LEN, use_negsample: forwarded to the base
                class.
            HIDDEN_SIZE: forwarded to the base class; controller/output width
                of each cell and passed to ``din_attention``.
            MEMORY_SIZE: forwarded to the base class; the slot count actually
                used here is ``args.memory_size``.
            BATCH_SIZE: forwarded; static batch size for cell state and masks.
            Mem_Induction, Util_Reg: accepted for interface compatibility;
                ``args.mem_induction`` and ``args.util_reg`` are used instead.
            mask_flag: when True, masked steps reset the cell state to its
                initial value.
            args: configuration namespace read for all MIMN settings.

        Raises:
            ValueError: if ``args`` is None, or a split's length is not
                divisible by ``args.mimn_seq_reduce``.
        """
        super(Model_MIMN, self).__init__(uid_n,
                                         item_n,
                                         cate_n,
                                         shop_n,
                                         node_n,
                                         product_n,
                                         brand_n,
                                         EMBEDDING_DIM,
                                         HIDDEN_SIZE,
                                         MEMORY_SIZE,
                                         BATCH_SIZE,
                                         SEQ_LEN,
                                         use_negsample,
                                         Flag="MIMN",
                                         args=args)
        logging.info(locals())
        if args is None:
            raise ValueError("Model_MIMN requires `args`: MIMN settings are "
                             "read from it")

        # MIMN hyper-parameters come from `args` and override the constructor
        # arguments of the same meaning.
        self.reg = args.util_reg
        Mem_Induction = args.mem_induction
        MEMORY_SIZE = args.memory_size

        def blend(new, old, keep):
            """Return ``keep * new + (1 - keep) * old`` (mask-weighted select)."""
            return keep * new + (1 - keep) * old

        def clear_mask_state(state, begin_state, begin_channel_rnn_state,
                             begin_channel_rnn_output, mask, cell, t):
            """Blend step ``t``'s state with the initial state by ``mask[:, t]``.

            ``mask`` must be the same (split / pooled) mask that indexes the
            cell's input steps. ``state`` is updated in place and returned;
            when ``Mem_Induction > 0`` the channel-RNN states and outputs are
            blended with their own initial values.
            """
            keep_2d = tf.reshape(mask[:, t], (BATCH_SIZE, 1))
            keep_3d = tf.reshape(mask[:, t], (BATCH_SIZE, 1, 1))
            state["controller_state"] = blend(state["controller_state"],
                                              begin_state["controller_state"],
                                              keep_2d)
            for key in ("M", "key_M", "sum_aggre"):
                state[key] = blend(state[key], begin_state[key], keep_3d)
            if Mem_Induction > 0:
                keep = tf.expand_dims(mask[:, t], axis=1)
                cell.channel_rnn_state = [
                    blend(cell.channel_rnn_state[i],
                          begin_channel_rnn_state[i], keep)
                    for i in range(MEMORY_SIZE)
                ]
                cell.channel_rnn_output = [
                    blend(cell.channel_rnn_output[i],
                          begin_channel_rnn_output[i], keep)
                    for i in range(MEMORY_SIZE)
                ]
            return state

        def parse_seq_split(spec):
            """Parse ``"l0:r0,l1:r1"`` into ``[(l0, r0), (l1, r1)]``."""
            pairs = []
            for part in spec.split(","):
                bounds = part.split(":")
                pairs.append((int(bounds[0]), int(bounds[1])))
            return pairs

        def reduce_time(x, seq_len, factor):
            """Mean-pool axis 1 of ``x`` over non-overlapping ``factor`` windows."""
            dim = x.get_shape().as_list()[-1]
            return tf.reduce_mean(tf.reshape(x, [-1, seq_len, factor, dim]),
                                  axis=-2)

        # NOTE: aliases self.inputs; the `+=` / append calls below extend it
        # in place.
        inputs = self.inputs

        mimn_seq_split = args.mimn_seq_split or args.long_seq_split
        if mimn_seq_split:
            for idx, (left_idx, right_idx) in enumerate(
                    parse_seq_split(mimn_seq_split)):
                SEQ_LEN = abs(right_idx - left_idx)
                # NOTE(review): only the log call and the mask slice are inside
                # this name scope; the layer's remaining ops are built outside.
                with tf.name_scope('MIMN_Layer_{0}'.format(idx)):
                    logging.info("mimn_layer {0}:{1}".format(
                        left_idx, right_idx))
                    mask = self.mask[:, left_idx:right_idx]

                item_his_eb = (self.item_his_eb[:, left_idx:right_idx] *
                               mask[:, :, None])
                if self.use_negsample:
                    neg_his_eb = (self.neg_his_eb[:, left_idx:right_idx] *
                                  mask[:, :, None])
                item_eb = self.item_eb
                if args.mimn_update_emb == 0:
                    # Keep MIMN gradients out of the shared embeddings.
                    item_his_eb = tf.stop_gradient(item_his_eb)
                    item_eb = tf.stop_gradient(item_eb)
                    if self.use_negsample:
                        neg_his_eb = tf.stop_gradient(neg_his_eb)
                memory_vector_dim = self.item_his_eb.get_shape().as_list()[-1]
                head_num = args.head_num or 1

                cell = mimn.MIMNCell(controller_units=HIDDEN_SIZE,
                                     memory_size=MEMORY_SIZE,
                                     memory_vector_dim=memory_vector_dim,
                                     read_head_num=head_num,
                                     write_head_num=head_num,
                                     reuse=False,
                                     output_dim=HIDDEN_SIZE,
                                     clip_value=100,
                                     batch_size=BATCH_SIZE,
                                     mem_induction=Mem_Induction,
                                     # Same flag that gates capacity_loss below.
                                     util_reg=self.reg)

                state = cell.zero_state(BATCH_SIZE, tf.float32)
                if Mem_Induction > 0:
                    begin_channel_rnn_state = cell.channel_rnn_state
                    begin_channel_rnn_output = cell.channel_rnn_output
                else:
                    begin_channel_rnn_state = begin_channel_rnn_output = 0.0

                begin_state = state
                self.state_list = [state]
                self.mimn_o = []

                if args.mimn_seq_reduce:
                    seq_reduce = args.mimn_seq_reduce
                    logging.info("mimn_seq_reduce:{0}".format(seq_reduce))
                    if SEQ_LEN % seq_reduce:
                        # reshape([-1, ...]) would otherwise fold the leftover
                        # steps into the batch axis or fail obscurely.
                        raise ValueError(
                            "split {0}:{1} has length {2}, not divisible by "
                            "mimn_seq_reduce={3}".format(
                                left_idx, right_idx, SEQ_LEN, seq_reduce))
                    SEQ_LEN = int(SEQ_LEN / seq_reduce)

                    item_his_eb = reduce_time(item_his_eb, SEQ_LEN, seq_reduce)
                    logging.info(item_his_eb.get_shape())
                    if self.use_negsample:
                        # Pool negatives identically so the auxiliary loss
                        # sees aligned time axes.
                        neg_his_eb = reduce_time(neg_his_eb, SEQ_LEN,
                                                 seq_reduce)

                    mask = tf.reshape(mask, [-1, SEQ_LEN, seq_reduce])
                    mask = tf.reduce_sum(mask, axis=-1)
                    # A pooled step is valid if any of its source steps was.
                    mask = tf.cast(tf.not_equal(mask, tf.zeros_like(mask)),
                                   tf.float32)

                for t in range(SEQ_LEN):
                    output, state, temp_output_list = cell(
                        item_his_eb[:, t, :], state)
                    if mask_flag:
                        state = clear_mask_state(state, begin_state,
                                                 begin_channel_rnn_state,
                                                 begin_channel_rnn_output,
                                                 mask, cell, t)
                    self.mimn_o.append(output)
                    self.state_list.append(state)

                self.mimn_o = tf.stack(self.mimn_o, axis=1)
                self.state_list.append(state)
                mean_memory = tf.reduce_mean(state['sum_aggre'], axis=-2)

                before_aggre = state['w_aggre']
                read_out, _, _ = cell(item_eb, state)

                # NOTE(review): with several splits, aux_loss / reg_loss /
                # mimn_o / state_list keep only the last split's values.
                if self.use_negsample:
                    self.aux_loss = self.auxiliary_loss(
                        self.mimn_o[:, :-1, :],
                        item_his_eb[:, 1:, :],
                        neg_his_eb[:, 1:, :],
                        mask[:, 1:],
                        stag="bigru_0")

                if self.reg:
                    self.reg_loss = cell.capacity_loss(before_aggre)
                else:
                    self.reg_loss = tf.zeros(1)

                if Mem_Induction == 1:
                    channel_memory_tensor = tf.concat(temp_output_list, 1)
                    multi_channel_hist = din_attention(item_eb,
                                                       channel_memory_tensor,
                                                       HIDDEN_SIZE,
                                                       None,
                                                       stag='pal',
                                                       att_func=args.att_func)
                    inputs += [
                        read_out,
                        tf.squeeze(multi_channel_hist), mean_memory * item_eb
                    ]
                else:
                    inputs += [read_out, mean_memory * item_eb]

        if args.long_seq_split:
            for left_idx, right_idx in parse_seq_split(args.long_seq_split):
                mask = self.mask[:, left_idx:right_idx]
                item_his_sum_emb = tf.reduce_sum(
                    self.item_his_eb[:, left_idx:right_idx] * mask[:, :, None],
                    1) / (tf.reduce_sum(mask, 1, keepdims=True) + 1.0)
                # One pooled feature per split (previously only the last
                # split's value survived the loop).
                inputs.append(item_his_sum_emb)

        inp = tf.concat(inputs, 1)
        self.build_fcn_net(inp, use_dice=False)