Example #1
class MODEL(nn.Module):
    def __init__(self,
                 n_question,
                 batch_size,
                 q_embed_dim,
                 qa_embed_dim,
                 memory_size,
                 memory_key_state_dim,
                 memory_value_state_dim,
                 final_fc_dim,
                 student_num=None):
        super(MODEL, self).__init__()
        self.n_question = n_question
        self.batch_size = batch_size
        self.q_embed_dim = q_embed_dim
        self.qa_embed_dim = qa_embed_dim
        self.memory_size = memory_size
        self.memory_key_state_dim = memory_key_state_dim
        self.memory_value_state_dim = memory_value_state_dim
        self.final_fc_dim = final_fc_dim
        self.student_num = student_num

        self.input_embed_linear = nn.Linear(self.q_embed_dim,
                                            self.q_embed_dim,
                                            bias=True)
        self.read_embed_linear = nn.Linear(self.memory_value_state_dim +
                                           self.q_embed_dim,
                                           self.final_fc_dim,
                                           bias=True)
        self.predict_linear = nn.Linear(self.final_fc_dim, 1, bias=True)
        self.init_memory_key = nn.Parameter(
            torch.randn(self.memory_size, self.memory_key_state_dim))
        nn.init.kaiming_normal_(self.init_memory_key)
        self.init_memory_value = nn.Parameter(
            torch.randn(self.memory_size, self.memory_value_state_dim))
        nn.init.kaiming_normal_(self.init_memory_value)
        self.mem = DKVMN(memory_size=self.memory_size,
                         memory_key_state_dim=self.memory_key_state_dim,
                         memory_value_state_dim=self.memory_value_state_dim,
                         init_memory_key=self.init_memory_key)

        memory_value = nn.Parameter(
            torch.cat([
                self.init_memory_value.unsqueeze(0) for _ in range(batch_size)
            ], 0).data)
        self.mem.init_value_memory(memory_value)

        # question ids start from 1
        # nn.Embedding takes a list of indices and returns the corresponding embedding vectors
        self.q_embed = nn.Embedding(self.n_question + 1,
                                    self.q_embed_dim,
                                    padding_idx=0)
        self.qa_embed = nn.Embedding(2 * self.n_question + 1,
                                     self.qa_embed_dim,
                                     padding_idx=0)

    def init_params(self):
        nn.init.kaiming_normal_(self.predict_linear.weight)
        nn.init.kaiming_normal_(self.read_embed_linear.weight)
        nn.init.constant_(self.read_embed_linear.bias, 0)
        nn.init.constant_(self.predict_linear.bias, 0)
        # nn.init.constant(self.input_embed_linear.bias, 0)
        # nn.init.normal(self.input_embed_linear.weight, std=0.02)

    def init_embeddings(self):
        nn.init.kaiming_normal_(self.q_embed.weight)
        nn.init.kaiming_normal_(self.qa_embed.weight)

    def forward(self, q_data, qa_data, target, student_id=None):

        batch_size = q_data.shape[0]  #32
        seqlen = q_data.shape[1]  #200

        ## qt && (q,a) embedding
        q_embed_data = self.q_embed(q_data)
        qa_embed_data = self.qa_embed(qa_data)

        ## copy mk batch times for dkvmn
        memory_value = nn.Parameter(
            torch.cat([
                self.init_memory_value.unsqueeze(0) for _ in range(batch_size)
            ], 0).data)
        self.mem.init_value_memory(memory_value)

        ## slice data for seqlen times by axis 1
        # torch.chunk(tensor, chunk_num, dim)
        slice_q_data = torch.chunk(q_data, seqlen, 1)
        slice_q_embed_data = torch.chunk(q_embed_data, seqlen, 1)
        slice_qa_embed_data = torch.chunk(qa_embed_data, seqlen, 1)

        value_read_content_l = []
        input_embed_l = []

        for i in range(seqlen):
            ## Attention
            q = slice_q_embed_data[i].squeeze(1)
            correlation_weight = self.mem.attention(q)

            ## Read Process
            read_content = self.mem.read(correlation_weight)

            ## save intermediate data
            value_read_content_l.append(read_content)
            input_embed_l.append(q)

            ## Write Process
            qa = slice_qa_embed_data[i].squeeze(1)
            new_memory_value = self.mem.write(correlation_weight, qa)

        # Projection
        all_read_value_content = torch.cat(
            [value_read_content_l[i].unsqueeze(1) for i in range(seqlen)], 1)
        input_embed_content = torch.cat(
            [input_embed_l[i].unsqueeze(1) for i in range(seqlen)], 1)

        ## Project rt
        input_embed_content = input_embed_content.view(batch_size * seqlen, -1)
        input_embed_content = torch.tanh(
            self.input_embed_linear(input_embed_content))
        input_embed_content = input_embed_content.view(batch_size, seqlen, -1)

        ## Concat Read_Content and input_embedding_value
        predict_input = torch.cat(
            [all_read_value_content, input_embed_content], 2)
        read_content_embed = torch.tanh(
            self.read_embed_linear(predict_input.view(batch_size * seqlen,
                                                      -1)))

        pred = self.predict_linear(read_content_embed)
        # predicts = torch.cat([predict_logs[i] for i in range(seqlen)], 1)
        target_1d = target  # [batch_size * seq_len, 1]
        # mask = target_1d.ge(0)               # [batch_size * seq_len, 1]
        mask = q_data.gt(0).view(-1, 1)
        # pred_1d = predicts.view(-1, 1)           # [batch_size * seq_len, 1]
        pred_1d = pred.view(-1, 1)  # [batch_size * seq_len, 1]

        filtered_pred = torch.masked_select(pred_1d, mask)
        filtered_target = torch.masked_select(target_1d, mask)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(
            filtered_pred, filtered_target)

        return loss, torch.sigmoid(filtered_pred), filtered_target
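
A note on the masking above: positions where q_data is 0 are padding, and dropping them before the loss is what lets the model train on variable-length sequences packed into fixed-length batches. Below is a minimal, self-contained sketch of that pattern with made-up shapes (plain PyTorch, independent of the MODEL and DKVMN classes above):

import torch
import torch.nn.functional as F

# Hypothetical batch: 2 sequences of length 4, where 0 marks padding.
q_data = torch.tensor([[3, 1, 2, 0],
                       [5, 4, 0, 0]])
logits = torch.randn(2, 4, 1)                    # one raw prediction per step
target = torch.randint(0, 2, (2, 4, 1)).float()  # 0/1 correctness labels

mask = q_data.gt(0).view(-1, 1)                  # keep only real (non-padded) steps
filtered_pred = torch.masked_select(logits.view(-1, 1), mask)
filtered_target = torch.masked_select(target.view(-1, 1), mask)
loss = F.binary_cross_entropy_with_logits(filtered_pred, filtered_target)
print(loss.item(), torch.sigmoid(filtered_pred))
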
Example #2
class MODEL(nn.Module):
    def __init__(self,
                 n_question,
                 batch_size,
                 q_embed_dim,
                 qa_embed_dim,
                 memory_size,
                 memory_key_state_dim,
                 memory_value_state_dim,
                 final_fc_dim,
                 first_k,
                 gpu,
                 student_num=None):
        super(MODEL, self).__init__()
        self.n_question = n_question
        self.batch_size = batch_size
        self.q_embed_dim = q_embed_dim
        self.qa_embed_dim = qa_embed_dim
        self.memory_size = memory_size
        self.memory_key_state_dim = memory_key_state_dim
        self.memory_value_state_dim = memory_value_state_dim
        self.final_fc_dim = final_fc_dim
        self.student_num = student_num
        self.first_k = first_k

        self.read_embed_linear = nn.Linear(self.memory_value_state_dim +
                                           self.q_embed_dim,
                                           self.final_fc_dim,
                                           bias=True)
        # self.predict_linear = nn.Linear(self.memory_value_state_dim + self.q_embed_dim, 1, bias=True)
        self.init_memory_key = nn.Parameter(
            torch.randn(self.memory_size, self.memory_key_state_dim))
        nn.init.kaiming_normal_(self.init_memory_key)
        self.init_memory_value = nn.Parameter(
            torch.randn(self.memory_size, self.memory_value_state_dim))
        nn.init.kaiming_normal_(self.init_memory_value)

        # modify hop_lstm
        self.hop_lstm = nn.LSTM(input_size=self.memory_value_state_dim +
                                self.q_embed_dim,
                                hidden_size=64,
                                num_layers=1,
                                batch_first=True)
        # hidden_size = 64
        self.predict_linear = nn.Linear(64, 1, bias=True)

        self.mem = DKVMN(memory_size=self.memory_size,
                         memory_key_state_dim=self.memory_key_state_dim,
                         memory_value_state_dim=self.memory_value_state_dim,
                         init_memory_key=self.init_memory_key)

        memory_value = nn.Parameter(
            torch.cat([
                self.init_memory_value.unsqueeze(0) for _ in range(batch_size)
            ], 0).data)
        self.mem.init_value_memory(memory_value)

        # question ids start from 1
        # nn.Embedding takes a list of indices and returns the corresponding embedding vectors
        self.q_embed = nn.Embedding(self.n_question + 1,
                                    self.q_embed_dim,
                                    padding_idx=0)
        self.a_embed = nn.Linear(2 * self.n_question + 1,
                                 self.qa_embed_dim,
                                 bias=True)
        # self.a_embed = nn.Linear(self.final_fc_dim + 1, self.qa_embed_dim, bias=True)

        # self.correlation_weight_list = []

        if gpu >= 0:
            self.device = torch.device('cuda', gpu)
        else:
            self.device = torch.device('cpu')

        print(
            "num_layers=1, hidden_size=64, a=0.075, b=0.088, c=1.00, triangular, onehot"
        )

    def init_params(self):
        nn.init.kaiming_normal_(self.predict_linear.weight)
        nn.init.kaiming_normal_(self.read_embed_linear.weight)
        nn.init.constant_(self.read_embed_linear.bias, 0)
        nn.init.constant_(self.predict_linear.bias, 0)

    def init_embeddings(self):
        nn.init.kaiming_normal_(self.q_embed.weight)

    # Method 2: set the top-k entries of the correlation weight vector to 1
    def identity_layer(self, correlation_weight, seqlen, k=1):
        batch_identity_indices = []
        correlation_weight = correlation_weight.view(self.batch_size * seqlen,
                                                     -1)

        # for every step of every sequence in the batch, set the top-k entries to 1 and the rest to 0
        _, indices = correlation_weight.topk(k, dim=1, largest=True)
        identity_matrix = torch.zeros(
            [self.batch_size * seqlen, self.memory_size])
        for i, m in enumerate(indices):
            identity_matrix[i, m] = 1

        identity_vector_batch = identity_matrix.view(self.batch_size * seqlen,
                                                     -1)

        unique_iv = torch.unique(identity_vector_batch, sorted=False, dim=0)
        self.unique_len = unique_iv.shape[0]

        # A^2
        iv_square_norm = torch.sum(torch.pow(identity_vector_batch, 2),
                                   dim=1,
                                   keepdim=True)
        iv_square_norm = iv_square_norm.repeat((1, self.unique_len))
        # B^2.T
        unique_iv_square_norm = torch.sum(torch.pow(unique_iv, 2),
                                          dim=1,
                                          keepdim=True)
        unique_iv_square_norm = unique_iv_square_norm.repeat(
            (1, self.batch_size * seqlen)).transpose(1, 0)
        # A * B.T
        iv_matrix_product = identity_vector_batch.mm(unique_iv.transpose(1, 0))
        # A^2 + B^2 - 2A*B.T
        iv_distances = iv_square_norm + unique_iv_square_norm - 2 * iv_matrix_product
        indices = (iv_distances == 0).nonzero()
        batch_identity_indices = indices[:, -1]

        return batch_identity_indices

    # Method 1: compute the identity vector with a triangular membership function
    def triangular_layer(self,
                         correlation_weight,
                         seqlen,
                         a=0.075,
                         b=0.088,
                         c=1.00):
        batch_identity_indices = []

        # w' = min((w-a)/(b-a), (c-w)/(c-b))
        # then clip: w' = max(w', 0)
        correlation_weight = correlation_weight.view(self.batch_size * seqlen,
                                                     -1)
        correlation_weight = torch.cat([
            correlation_weight[i] for i in range(correlation_weight.shape[0])
        ], 0).unsqueeze(0)
        correlation_weight = torch.cat([(correlation_weight - a) / (b - a),
                                        (c - correlation_weight) / (c - b)], 0)
        correlation_weight, _ = torch.min(correlation_weight, 0)
        w0 = torch.zeros(correlation_weight.shape[0]).to(self.device)
        correlation_weight = torch.cat(
            [correlation_weight.unsqueeze(0),
             w0.unsqueeze(0)], 0)
        correlation_weight, _ = torch.max(correlation_weight, 0)

        identity_vector_batch = torch.zeros(correlation_weight.shape[0]).to(
            self.device)

        # values >= 0.6 are set to 2, values in [0.1, 0.6) to 1, values below 0.1 to 0
        # mask = correlation_weight.lt(0.1)
        identity_vector_batch = identity_vector_batch.masked_fill(
            correlation_weight.lt(0.1), 0)
        # mask = correlation_weight.ge(0.1)
        identity_vector_batch = identity_vector_batch.masked_fill(
            correlation_weight.ge(0.1), 1)
        # mask = correlation_weight.ge(0.6)
        _identity_vector_batch = identity_vector_batch.masked_fill(
            correlation_weight.ge(0.6), 2)

        # identity_vector_batch = torch.chunk(identity_vector_batch.view(self.batch_size, -1), self.batch_size, 0)

        # input: _identity_vector_batch
        # output: indices
        identity_vector_batch = _identity_vector_batch.view(
            self.batch_size * seqlen, -1)

        unique_iv = torch.unique(identity_vector_batch, sorted=False, dim=0)
        self.unique_len = unique_iv.shape[0]

        # A^2
        iv_square_norm = torch.sum(torch.pow(identity_vector_batch, 2),
                                   dim=1,
                                   keepdim=True)
        iv_square_norm = iv_square_norm.repeat((1, self.unique_len))
        # B^2.T
        unique_iv_square_norm = torch.sum(torch.pow(unique_iv, 2),
                                          dim=1,
                                          keepdim=True)
        unique_iv_square_norm = unique_iv_square_norm.repeat(
            (1, self.batch_size * seqlen)).transpose(1, 0)
        # A * B.T
        iv_matrix_product = identity_vector_batch.mm(unique_iv.transpose(1, 0))
        # A^2 + B^2 - 2A*B.T
        iv_distances = iv_square_norm + unique_iv_square_norm - 2 * iv_matrix_product
        indices = (iv_distances == 0).nonzero()
        batch_identity_indices = indices[:, -1]

        return batch_identity_indices

    def forward(self, q_data, qa_data, a_data, target, student_id=None):

        batch_size = q_data.shape[0]  #32
        seqlen = q_data.shape[1]  #200

        ## qt && (q,a) embedding
        q_embed_data = self.q_embed(q_data)

        # modify: build the y_t one-hot vector for each question
        a_onehot_array = []
        for i in range(a_data.shape[0]):
            for j in range(a_data.shape[1]):
                a_onehot = np.zeros(self.n_question + 1)
                index = a_data[i][j]
                if index > 0:
                    a_onehot[index] = 1
                a_onehot_array.append(a_onehot)
        a_onehot_content = torch.cat([
            torch.Tensor(a_onehot_array[i]).unsqueeze(0)
            for i in range(len(a_onehot_array))
        ], 0)
        a_onehot_content = a_onehot_content.view(batch_size, seqlen,
                                                 -1).to(self.device)

        ## copy mk batch times for dkvmn
        memory_value = nn.Parameter(
            torch.cat([
                self.init_memory_value.unsqueeze(0) for _ in range(batch_size)
            ], 0).data)
        self.mem.init_value_memory(memory_value)

        ## slice data for seqlen times by axis 1
        slice_q_data = torch.chunk(q_data, seqlen, 1)
        slice_q_embed_data = torch.chunk(q_embed_data, seqlen, 1)

        # modify
        slice_a_onehot_content = torch.chunk(a_onehot_content, seqlen, 1)
        # slice_a = torch.chunk(a_data, seqlen, 1)

        value_read_content_l = []
        input_embed_l = []
        correlation_weight_list = []

        # modify
        f_t = []

        # (n_layers,batch_size,hidden_dim)
        init_h = torch.randn(1, self.batch_size, 64).to(self.device)
        init_c = torch.randn(1, self.batch_size, 64).to(self.device)

        for i in range(seqlen):
            ## Attention
            q = slice_q_embed_data[i].squeeze(1)
            correlation_weight = self.mem.attention(q)

            ## Read Process
            read_content = self.mem.read(correlation_weight)

            # modify
            correlation_weight_list.append(correlation_weight)

            ## save intermediate data
            value_read_content_l.append(read_content)
            input_embed_l.append(q)

            # modify
            batch_predict_input = torch.cat([read_content, q], 1)
            f = self.read_embed_linear(batch_predict_input)
            f_t.append(batch_predict_input)

            # write input to the value matrix is [y_t, f_t]: the one-hot vector concatenated with f_t
            onehot = slice_a_onehot_content[i].squeeze(1)
            write_embed = torch.cat([onehot, f], 1)

            # alternative: write [f_t, y_t], concatenating f_t directly with the 0/1 correctness
            # write_embed = torch.cat([f, slice_a[i].float()], 1)

            write_embed = self.a_embed(write_embed)
            new_memory_value = self.mem.write(correlation_weight, write_embed)

        # modify
        correlation_weight_matrix = torch.cat(
            [correlation_weight_list[i].unsqueeze(1) for i in range(seqlen)],
            1)
        identity_index_list = self.triangular_layer(correlation_weight_matrix,
                                                    seqlen)
        # identity_index_list = self.identity_layer(correlation_weight_matrix, seqlen)
        identity_index_list = identity_index_list.view(self.batch_size, seqlen)
        # identity_index_list = identity_index_list[:, self.first_k:]  # skip prediction for the first k steps

        # identity_index_list = torch.cat([identity_index_list[i].unsqueeze(1) for i in range(seqlen)], 1)
        f_t = torch.cat([f_t[i].unsqueeze(1) for i in range(seqlen)], 1)
        # f_t = f_t[:, self.first_k:]  # skip prediction for the first k steps
        target_seqlayer = target.view(batch_size, seqlen, -1)
        # target_seqlayer = target_seqlayer[:, self.first_k:]  # skip prediction for the first k steps

        target_sequence = []
        pred_sequence = []

        for idx in range(self.unique_len):
            # start = time.time()
            hop_lstm_input = []
            hop_lstm_target = []
            max_seq = 1
            zero_count = 0
            for i in range(self.batch_size):
                # for each sequence, get the indices of the questions whose identity vector matches the one currently being predicted
                index = list((identity_index_list[i, :] == idx).nonzero())
                max_seq = max(max_seq, len(index))
                if len(index) == 0:
                    hop_lstm_input.append(
                        torch.zeros([
                            1, self.memory_value_state_dim + self.q_embed_dim
                        ]))
                    hop_lstm_target.append(torch.full([1, 1], -1))
                    zero_count += 1
                    continue
                else:
                    index = torch.LongTensor(index).to(self.device)
                    hop_lstm_target_slice = torch.index_select(
                        target_seqlayer[i, :, :], 0, index)
                    hop_lstm_input_slice = torch.index_select(
                        f_t[i, :, :], 0, index)
                    hop_lstm_input.append(hop_lstm_input_slice)
                    hop_lstm_target.append(hop_lstm_target_slice)

            if zero_count == self.batch_size:
                continue

            # pad the input and target matrices
            for i in range(self.batch_size):
                x = torch.zeros(
                    [max_seq, self.memory_value_state_dim + self.q_embed_dim])
                x[:len(hop_lstm_input[i]), :] = hop_lstm_input[i]
                hop_lstm_input[i] = x
                y = torch.full([max_seq, 1], -1)
                y[:len(hop_lstm_target[i]), :] = hop_lstm_target[i]
                hop_lstm_target[i] = y

            # run the hop LSTM to predict
            hop_lstm_input = torch.cat([
                hop_lstm_input[i].unsqueeze(0) for i in range(self.batch_size)
            ], 0).to(self.device)
            hop_lstm_target = torch.cat([
                hop_lstm_target[i].unsqueeze(0) for i in range(self.batch_size)
            ], 0)

            hop_lstm_output, _ = self.hop_lstm(hop_lstm_input,
                                               (init_h, init_c))
            pred = self.predict_linear(hop_lstm_output)
            pred = pred.view(self.batch_size * max_seq, -1)
            hop_lstm_target = hop_lstm_target.view(self.batch_size * max_seq,
                                                   -1).to(self.device)
            mask = hop_lstm_target.ge(0)
            hop_lstm_target = torch.masked_select(hop_lstm_target, mask)
            pred = torch.sigmoid(torch.masked_select(pred, mask))
            target_sequence.append(hop_lstm_target)
            pred_sequence.append(pred)

            # during training, back-propagate separately through the LSTM run of each identity vector
            if self.training:
                # pred already passed through sigmoid above, so use plain BCE here
                subsequence_loss = torch.nn.functional.binary_cross_entropy(
                    pred, hop_lstm_target)
                subsequence_loss.backward(retain_graph=True)

        # compute the loss over all questions in the batch
        target_sequence = torch.cat(
            [target_sequence[i] for i in range(len(target_sequence))], 0)
        pred_sequence = torch.cat(
            [pred_sequence[i] for i in range(len(pred_sequence))], 0)

        # pred_sequence already holds probabilities, so use plain BCE rather than the with_logits variant
        loss = torch.nn.functional.binary_cross_entropy(
            pred_sequence, target_sequence)

        return loss, pred_sequence, target_sequence
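
The triangular_layer above fuzzifies each correlation weight with a triangular membership function, w' = max(min((w - a)/(b - a), (c - w)/(c - b)), 0), and then buckets w' into the codes {0, 1, 2} that form the identity vector. A minimal sketch of that transform on a single weight vector, using the default a, b, c values printed in __init__ (an illustration only, not the class method itself):

import torch

def fuzzify(w, a=0.075, b=0.088, c=1.00):
    # triangular membership: rises on [a, b], falls on [b, c], clipped at 0
    w_prime = torch.clamp(torch.min((w - a) / (b - a), (c - w) / (c - b)), min=0.0)
    # bucket into identity codes: < 0.1 -> 0, [0.1, 0.6) -> 1, >= 0.6 -> 2
    codes = torch.zeros_like(w_prime)
    codes[w_prime.ge(0.1)] = 1
    codes[w_prime.ge(0.6)] = 2
    return codes

w = torch.softmax(torch.randn(20), dim=0)   # one correlation-weight row over 20 memory slots
print(fuzzify(w))
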
Example #3
    def sym_gen(self):
        ### TODO input variable 'q_data'
        q_data = mx.sym.Variable('q_data', shape=(self.seqlen, self.batch_size)) # (seqlen, batch_size)
        ### TODO input variable 'qa_data'
        qa_data = mx.sym.Variable('qa_data', shape=(self.seqlen, self.batch_size))  # (seqlen, batch_size)
        ### TODO input variable 'target'
        target = mx.sym.Variable('target', shape=(self.seqlen, self.batch_size)) #(seqlen, batch_size)

        ### Initialize Memory
        init_memory_key = mx.sym.Variable('init_memory_key_weight')
        init_memory_value = mx.sym.Variable('init_memory_value',
                                            shape=(self.memory_size, self.memory_value_state_dim),
                                            init=mx.init.Normal(0.1)) # (self.memory_size, self.memory_value_state_dim)
        init_memory_value = mx.sym.broadcast_to(mx.sym.expand_dims(init_memory_value, axis=0),
                                                shape=(self.batch_size, self.memory_size, self.memory_value_state_dim))

        mem = DKVMN(memory_size=self.memory_size,
                   memory_key_state_dim=self.memory_key_state_dim,
                   memory_value_state_dim=self.memory_value_state_dim,
                   init_memory_key=init_memory_key,
                   init_memory_value=init_memory_value,
                   name="DKVMN")


        ### embedding
        q_data = mx.sym.BlockGrad(q_data)
        q_embed_data = mx.sym.Embedding(data=q_data, input_dim=self.n_question+1,
                                        output_dim=self.q_embed_dim, name='q_embed')
        slice_q_embed_data = mx.sym.SliceChannel(q_embed_data, num_outputs=self.seqlen, axis=0, squeeze_axis=True)

        qa_data = mx.sym.BlockGrad(qa_data)
        qa_embed_data = mx.sym.Embedding(data=qa_data, input_dim=self.n_question*2+1,
                                         output_dim=self.qa_embed_dim, name='qa_embed')
        slice_qa_embed_data = mx.sym.SliceChannel(qa_embed_data, num_outputs=self.seqlen, axis=0, squeeze_axis=True)

        value_read_content_l = []
        input_embed_l = []
        
        readDict = {object :[]}
        
        for i in range(self.seqlen):
            ## Attention
            
            q = mx.sym.L2Normalization(slice_q_embed_data[i], mode='instance')
            correlation_weight = mem.attention(q)
            
            ## Read Process
            read_content = mem.read(correlation_weight) #Shape (batch_size, memory_state_dim)
            
            
            ### save intermediate data [OLD]
            value_read_content_l.append(read_content)
          
            input_embed_l.append(q)
            
             ## Write Process
            #qa = slice_qa_embed_data[i] 
            #new_memory_value = mem.write(correlation_weight, mx.sym.Concat(qa , read_content))
            qa = mx.sym.concat(mx.sym.L2Normalization(read_content, mode='instance'),mx.sym.L2Normalization(slice_qa_embed_data[i], mode='instance'))
            #qa=mx.sym.L2Normalization(slice_qa_embed_data[i], mode='instance')
            new_memory_value = mem.write(correlation_weight,qa)

        #================================[ Cluster related read_contents based on fuzzy representation] ==============
        for i in range(0,len(value_read_content_l)):
            current_fuzz_rep = mx.symbol.Custom(data=value_read_content_l[i], name='fuzzkey', op_type='fuzzify')
            related = [value_read_content_l[i]]
            for j in range(0,len(value_read_content_l)):
                if i != j:
                    tmp_fuzz = mx.symbol.Custom(data=value_read_content_l[j], name='fuzzkey', op_type='fuzzify')
                    if current_fuzz_rep.tojson() == tmp_fuzz.tojson():
                        related.append(value_read_content_l[j])
                        
            value_read_content_l[i] = mx.sym.Reshape(data=mx.sym.RNN(data=related,state_size=self.memory_value_state_dim,num_layers=2,mode ='lstm',p =0.2), # Shape (batch_size, 1, memory_state_dim)
                                 shape=(-1,self.memory_value_state_dim)) 
                        
        #=================================================================================
        
        all_read_value_content = mx.sym.Concat(*value_read_content_l, num_args=self.seqlen, dim=0)

        input_embed_content = mx.sym.Concat(*input_embed_l, num_args=self.seqlen, dim=0) 
        input_embed_content = mx.sym.FullyConnected(data=mx.sym.L2Normalization(input_embed_content, mode='instance'), num_hidden=64, name="input_embed_content")
        input_embed_content = mx.sym.Activation(data=mx.sym.L2Normalization(input_embed_content, mode='instance'), act_type='tanh', name="input_embed_content_tanh")


        read_content_embed = mx.sym.FullyConnected(data=mx.sym.Concat(mx.sym.L2Normalization(all_read_value_content, mode='instance'), mx.sym.L2Normalization(input_embed_content, mode='instance'), num_args=2, dim=1),
                                                   num_hidden=self.final_fc_dim, name="read_content_embed") 
       
        read_content_embed = mx.sym.Activation(data= mx.sym.L2Normalization(read_content_embed, mode='instance'), act_type='tanh', name="read_content_embed_tanh")  
        
        #================================================[ Updated for F value]====================================
#        for i in range(self.seqlen):
#           ## Write Process
#           qa = mx.symbol.batch_dot(slice_qa_embed_data[i],read_content_embed)
#           #qa = mx.sym.Concat(slice_qa_embed_data[i],read_content_embed)
#           #qa = read_content_embed
#           new_memory_value = mem.write(correlation_weight, qa)

        #==========================================================================================================
         
        pred = mx.sym.FullyConnected(data=mx.sym.L2Normalization(read_content_embed, mode='instance'), num_hidden=1, name="final_fc")

        pred_prob = logistic_regression_mask_output(data=mx.sym.Reshape(pred, shape=(-1, )),
                                                    label=mx.sym.Reshape(data=target, shape=(-1,)),
                                                    ignore_label=-1., name='final_pred')
        return mx.sym.Group([pred_prob])
Example #4
class Model():
    def __init__(self, args, sess, name='KT'):
        self.args = args
        self.name = name
        self.sess = sess
        self.create_model()
        self.sess.run(tf.global_variables_initializer())
        # if self.load():
        # 	print('CKPT Loaded')
        # else:
        # 	raise Exception('CKPT need')

    def create_model(self):
        # 'seq_len' means question sequences
        self.q_data = tf.placeholder(tf.int32,
                                     [self.args.batch_size, self.args.seq_len],
                                     name='q_data')
        self.qa_data = tf.placeholder(
            tf.int32, [self.args.batch_size, self.args.seq_len],
            name='qa_data')
        self.target = tf.placeholder(tf.float32,
                                     [self.args.batch_size, self.args.seq_len],
                                     name='target')
        self.kg = tf.placeholder(tf.int32,
                                 [self.args.batch_size, self.args.seq_len, 3],
                                 name='knowledge_tag')
        self.kg_hot = tf.placeholder(
            tf.float32, [self.args.batch_size, self.args.seq_len, 188],
            name='knowledge_hot')
        self.timebin = tf.placeholder(
            tf.int32, [self.args.batch_size, self.args.seq_len])
        self.diff = tf.placeholder(tf.int32,
                                   [self.args.batch_size, self.args.seq_len])
        self.guan = tf.placeholder(tf.int32,
                                   [self.args.batch_size, self.args.seq_len])

        with tf.variable_scope('Memory'):
            init_memory_key = tf.get_variable('key', [self.args.memory_size, self.args.memory_key_state_dim], \
             initializer=tf.truncated_normal_initializer(stddev=0.1))
            init_memory_value = tf.get_variable('value', [self.args.memory_size,self.args.memory_value_state_dim], \
             initializer=tf.truncated_normal_initializer(stddev=0.1))
        with tf.variable_scope('time'):
            time_embed_mtx = tf.get_variable('timebin', [12, self.args.memory_value_state_dim],\
             initializer=tf.truncated_normal_initializer(stddev=0.1))
        with tf.variable_scope('diff'):
            guan_embed_mtx = tf.get_variable('diff', [12, self.args.memory_value_state_dim],\
             initializer=tf.truncated_normal_initializer(stddev=0.1))

        with tf.variable_scope('gate'):
            diff_embed_mtx = tf.get_variable('gate', [12, self.args.memory_value_state_dim],\
             initializer=tf.truncated_normal_initializer(stddev=0.1))

        init_memory_value = tf.tile(tf.expand_dims(init_memory_value, 0),
                                    tf.stack([self.args.batch_size, 1, 1]))
        print(init_memory_value.get_shape())

        self.memory = DKVMN(self.args.memory_size, self.args.memory_key_state_dim, \
          self.args.memory_value_state_dim, init_memory_key=init_memory_key, init_memory_value=init_memory_value, batch_size=self.args.batch_size, name='DKVMN')

        with tf.variable_scope('Embedding'):
            # A
            q_embed_mtx = tf.get_variable('q_embed', [self.args.n_questions+1, self.args.memory_key_state_dim],\
             initializer=tf.truncated_normal_initializer(stddev=0.1))
            # B
            qa_embed_mtx = tf.get_variable(
                'qa_embed', [
                    2 * self.args.n_questions + 1,
                    self.args.memory_value_state_dim
                ],
                initializer=tf.truncated_normal_initializer(stddev=0.1))

        q_embed_data = tf.nn.embedding_lookup(q_embed_mtx, self.q_data)
        slice_q_embed_data = tf.split(q_embed_data, self.args.seq_len, 1)

        qa_embed_data = tf.nn.embedding_lookup(qa_embed_mtx, self.qa_data)
        slice_qa_embed_data = tf.split(qa_embed_data, self.args.seq_len, 1)

        time_embedding = tf.nn.embedding_lookup(time_embed_mtx, self.timebin)
        slice_time_embedding = tf.split(time_embedding, self.args.seq_len, 1)

        guan_embedding = tf.nn.embedding_lookup(guan_embed_mtx, self.guan)
        slice_guan_embedding = tf.split(guan_embedding, self.args.seq_len, 1)

        diff_embedding = tf.nn.embedding_lookup(diff_embed_mtx, self.diff)
        slice_diff_embedding = tf.split(diff_embedding, self.args.seq_len, 1)

        slice_kg = tf.split(self.kg, self.args.seq_len, 1)

        slice_kg_hot = tf.split(self.kg_hot, self.args.seq_len, 1)

        reuse_flag = False

        prediction = list()

        # Logics
        for i in range(self.args.seq_len):
            # To reuse linear vectors
            if i != 0:
                reuse_flag = True

            q = tf.squeeze(slice_q_embed_data[i], 1)
            qa = tf.squeeze(slice_qa_embed_data[i], 1)
            kg = tf.squeeze(slice_kg[i], 1)
            kg_hot = tf.squeeze(slice_kg_hot[i], 1)
            dotime = tf.squeeze(slice_time_embedding[i], 1)
            dodiff = tf.squeeze(slice_diff_embedding[i], 1)
            doguan = tf.squeeze(slice_guan_embedding[i], 1)

            self.correlation_weight = self.memory.attention(q, kg, kg_hot)

            # # Read process, [batch size, memory value state dim]
            self.read_content = self.memory.read(self.correlation_weight)

            mastery_level_prior_difficulty = tf.concat(
                [self.read_content, q, doguan], 1)

            # f_t
            summary_vector = tf.tanh(
                operations.linear(mastery_level_prior_difficulty,
                                  self.args.final_fc_dim,
                                  name='Summary_Vector',
                                  reuse=reuse_flag))
            # p_t
            pred_logits = operations.linear(summary_vector,
                                            1,
                                            name='Prediction',
                                            reuse=reuse_flag)

            prediction.append(pred_logits)

            qa_time = tf.concat([qa, dotime], axis=1)

            self.new_memory_value = self.memory.write(self.correlation_weight,
                                                      qa_time,
                                                      reuse=reuse_flag)

# 'prediction' : seq_len length list of [batch size ,1], make it [batch size, seq_len] tensor
# tf.stack convert to [batch size, seq_len, 1]
        self.pred_logits = tf.reshape(tf.stack(
            prediction, axis=1), [self.args.batch_size, self.args.seq_len])

        # Define loss : standard cross entropy loss, need to ignore '-1' label example
        # Make target/label 1-d array
        target_1d = tf.reshape(self.target, [-1])
        pred_logits_1d = tf.reshape(self.pred_logits, [-1])
        index = tf.where(
            tf.not_equal(target_1d, tf.constant(-1., dtype=tf.float32)))
        # tf.gather(params, indices) : Gather slices from params according to indices
        filtered_target = tf.gather(target_1d, index)
        filtered_logits = tf.gather(pred_logits_1d, index)

        self.loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=filtered_logits,
                                                    labels=filtered_target))
        self.pred = tf.sigmoid(self.pred_logits)

        # Optimizer : SGD + MOMENTUM with learning rate decay
        self.global_step = tf.Variable(0, trainable=False)
        self.lr = tf.placeholder(tf.float32, [], name='learning_rate')

        optimizer = tf.train.MomentumOptimizer(self.lr, self.args.momentum)
        grads, vrbs = zip(*optimizer.compute_gradients(self.loss))
        grad, _ = tf.clip_by_global_norm(grads, self.args.maxgradnorm)
        self.train_op = optimizer.apply_gradients(zip(grad, vrbs),
                                                  global_step=self.global_step)
        self.tr_vrbs = tf.trainable_variables()
        self.params = {}
        for i in self.tr_vrbs:
            print(i.name)
            self.params[i.name] = tf.get_default_graph().get_tensor_by_name(
                i.name)
        self.saver = tf.train.Saver()

    def getParam(self):
        """
		Get parameters of DKVMN-CA model
		:return:
		"""
        params = self.sess.run([self.params])
        with open('good_future.pkl', 'wb') as f:
            pickle.dump(params, f)

    def train(self, train_q_data, train_qa_data, valid_q_data, valid_qa_data,
              train_kg_data, valid_kg_data, train_kgnum_data, valid_kgnum_data,
              traintime, validtime, trainguan, validguan, traindiff,
              validdiff):
        """

		:param train_q_data: exercises ID
		:param train_qa_data: exercises ID and answer result
		:param train_kg_data: one-hot form of knowledge concepts
		:param train_kgnum_data: knowledge concepts tags
		:param traintime: completion time
		:param trainguan: the gate of exercise
		:param traindiff: the difficulty of exercise
		"""
        shuffle_index = np.random.permutation(train_q_data.shape[0])
        q_data_shuffled = train_q_data[shuffle_index]
        qa_data_shuffled = train_qa_data[shuffle_index]
        kg_shuffled = train_kgnum_data[shuffle_index]
        kghot_shuffled = train_kg_data[shuffle_index]
        time_shuffled = traintime[shuffle_index]
        guan_shuffled = trainguan[shuffle_index]
        diff_shuffled = traindiff[shuffle_index]
        training_step = train_q_data.shape[0] // self.args.batch_size
        self.sess.run(tf.global_variables_initializer())

        self.train_count = 0
        if self.args.init_from:
            if self.load():
                print('Checkpoint_loaded')
            else:
                print('No checkpoint')
        else:
            if os.path.exists(
                    os.path.join(self.args.checkpoint_dir, self.model_dir)):
                try:
                    shutil.rmtree(
                        os.path.join(self.args.checkpoint_dir, self.model_dir))
                    os.remove(
                        os.path.join(self.args.log_dir,
                                     self.model_dir + '.csv'))
                except (FileNotFoundError, IOError) as e:
                    print('[Delete Error] %s - %s' % (e.filename, e.strerror))

        best_valid_auc = 0

        # Training
        for epoch in range(0, self.args.num_epochs):
            if self.args.show:
                bar.next()

            pred_list = list()
            target_list = list()
            epoch_loss = 0

            for steps in range(training_step):
                # [batch size, seq_len]
                q_batch_seq = q_data_shuffled[steps *
                                              self.args.batch_size:(steps +
                                                                    1) *
                                              self.args.batch_size, :]
                qa_batch_seq = qa_data_shuffled[steps *
                                                self.args.batch_size:(steps +
                                                                      1) *
                                                self.args.batch_size, :]
                kg_batch_seq = kg_shuffled[steps *
                                           self.args.batch_size:(steps + 1) *
                                           self.args.batch_size, :]
                kghot_batch_seq = kghot_shuffled[steps *
                                                 self.args.batch_size:(steps +
                                                                       1) *
                                                 self.args.batch_size, :]
                time_batch_seq = time_shuffled[steps *
                                               self.args.batch_size:(steps +
                                                                     1) *
                                               self.args.batch_size, :]
                guan_batch_seq = guan_shuffled[steps *
                                               self.args.batch_size:(steps +
                                                                     1) *
                                               self.args.batch_size, :]
                diff_batch_seq = diff_shuffled[steps *
                                               self.args.batch_size:(steps +
                                                                     1) *
                                               self.args.batch_size, :]
                # qa : exercise index + answer(0 or 1)*exercise_number
                # right : 1, wrong : 0, padding : -1
                target = qa_batch_seq[:, :]
                # Make integer type to calculate target
                target = target.astype(int)
                target_batch = (target - 1) // self.args.n_questions
                target_batch = target_batch.astype(float)

                feed_dict = {
                    self.kg: kg_batch_seq,
                    self.q_data: q_batch_seq,
                    self.qa_data: qa_batch_seq,
                    self.target: target_batch,
                    self.kg_hot: kghot_batch_seq,
                    self.lr: self.args.initial_lr,
                    self.timebin: time_batch_seq,
                    self.diff: diff_batch_seq,
                    self.guan: guan_batch_seq
                }

                loss_, pred_, _, = self.sess.run(
                    [self.loss, self.pred, self.train_op], feed_dict=feed_dict)

                right_target = np.asarray(target_batch).reshape(-1, 1)
                right_pred = np.asarray(pred_).reshape(-1, 1)

                right_index = np.flatnonzero(right_target != -1.).tolist()

                pred_list.append(right_pred[right_index])
                target_list.append(right_target[right_index])

                epoch_loss += loss_

            if self.args.show:
                bar.finish()

            all_pred = np.concatenate(pred_list, axis=0)
            all_target = np.concatenate(target_list, axis=0)

            # Compute metrics
            self.auc = metrics.roc_auc_score(all_target, all_pred)
            # Extract elements with boolean index
            # Make '1' for elements higher than 0.5
            # Make '0' for elements lower than 0.5
            all_pred[all_pred > 0.5] = 1
            all_pred[all_pred <= 0.5] = 0

            self.accuracy = metrics.accuracy_score(all_target, all_pred)

            epoch_loss = epoch_loss / training_step
            print('Epoch %d/%d, loss : %3.5f, auc : %3.5f, accuracy : %3.5f' %
                  (epoch + 1, self.args.num_epochs, epoch_loss, self.auc,
                   self.accuracy))
            self.write_log(epoch=epoch + 1,
                           auc=self.auc,
                           accuracy=self.accuracy,
                           loss=epoch_loss,
                           name='training_')

            valid_steps = valid_q_data.shape[0] // self.args.batch_size
            valid_pred_list = list()
            valid_target_list = list()
            for s in range(valid_steps):
                # Validation
                valid_q = valid_q_data[s * self.args.batch_size:(s + 1) *
                                       self.args.batch_size, :]
                valid_qa = valid_qa_data[s * self.args.batch_size:(s + 1) *
                                         self.args.batch_size, :]
                valid_kg = valid_kgnum_data[s * self.args.batch_size:(s + 1) *
                                            self.args.batch_size, :]
                valid_hot_kg = valid_kg_data[s * self.args.batch_size:(s + 1) *
                                             self.args.batch_size, :]
                valid_time = validtime[s * self.args.batch_size:(s + 1) *
                                       self.args.batch_size, :]
                valid_guan = validguan[s * self.args.batch_size:(s + 1) *
                                       self.args.batch_size, :]
                valid_diff = validdiff[s * self.args.batch_size:(s + 1) *
                                       self.args.batch_size, :]
                # right : 1, wrong : 0, padding : -1
                valid_target = (valid_qa - 1) // self.args.n_questions
                valid_feed_dict = {
                    self.kg: valid_kg,
                    self.q_data: valid_q,
                    self.qa_data: valid_qa,
                    self.kg_hot: valid_hot_kg,
                    self.target: valid_target,
                    self.timebin: valid_time,
                    self.guan: valid_guan,
                    self.diff: valid_diff
                }
                valid_loss, valid_pred = self.sess.run(
                    [self.loss, self.pred], feed_dict=valid_feed_dict)
                # Same with training set
                valid_right_target = np.asarray(valid_target).reshape(-1, )
                valid_right_pred = np.asarray(valid_pred).reshape(-1, )
                valid_right_index = np.flatnonzero(
                    valid_right_target != -1).tolist()
                valid_target_list.append(valid_right_target[valid_right_index])
                valid_pred_list.append(valid_right_pred[valid_right_index])

            all_valid_pred = np.concatenate(valid_pred_list, axis=0)
            all_valid_target = np.concatenate(valid_target_list, axis=0)

            valid_auc = metrics.roc_auc_score(all_valid_target, all_valid_pred)
            # For validation accuracy
            stop = 0
            all_valid_pred[all_valid_pred > 0.5] = 1
            all_valid_pred[all_valid_pred <= 0.5] = 0
            valid_accuracy = metrics.accuracy_score(all_valid_target,
                                                    all_valid_pred)
            print('Epoch %d/%d, valid auc : %3.5f, valid accuracy : %3.5f' %
                  (epoch + 1, self.args.num_epochs, valid_auc, valid_accuracy))
            # Valid log
            self.write_log(epoch=epoch + 1,
                           auc=valid_auc,
                           accuracy=valid_accuracy,
                           loss=valid_loss,
                           name='valid_')

            if valid_auc > best_valid_auc:
                print('%3.4f to %3.4f' % (best_valid_auc, valid_auc))
                best_valid_auc = valid_auc
                best_acc = valid_accuracy
                best_epoch = epoch + 1
                # self.save(best_epoch)
            else:
                if epoch - best_epoch >= 2:

                    with open(self.args.dataset + 'concat', 'a') as f:
                        f.write('auc:' + str(best_valid_auc) + ',acc:' +
                                str(best_acc) + '\n')
                    self.args.count += 1
                    break

    @property
    def model_dir(self):
        return '{}_{}batch_{}epochs'.format(
            self.args.dataset + str(self.args.count), self.args.batch_size,
            self.args.num_epochs)

    def load(self):
        checkpoint_dir = os.path.join(self.args.checkpoint_dir, self.model_dir)
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.train_count = int(ckpt_name.split('-')[-1])
            self.saver.restore(self.sess,
                               os.path.join(checkpoint_dir, ckpt_name))
            return True
        else:
            return False

    def save(self, global_step):
        model_name = 'DKVMN'
        checkpoint_dir = os.path.join(self.args.checkpoint_dir, self.model_dir)
        if not os.path.exists(checkpoint_dir):
            os.mkdir(checkpoint_dir)
        self.saver.save(self.sess,
                        os.path.join(checkpoint_dir, model_name),
                        global_step=global_step)
        print('Save checkpoint at %d' % (global_step + 1))

    # Log file
    def write_log(self, auc, accuracy, loss, epoch, name='training_'):
        log_path = os.path.join(self.args.log_dir,
                                name + self.model_dir + '.csv')
        if not os.path.exists(log_path):
            self.log_file = open(log_path, 'w')
            self.log_file.write('Epoch\tAuc\tAccuracy\tloss\n')
        else:
            self.log_file = open(log_path, 'a')

        self.log_file.write(
            str(epoch) + '\t' + str(auc) + '\t' + str(accuracy) + '\t' +
            str(loss) + '\n')
        self.log_file.flush()
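
A note on the target construction in train() above: qa encodes both the question and the answer as qa = q + answer * n_questions with 1-based question ids, so (qa - 1) // n_questions recovers the 0/1 answer, and padding (qa = 0) maps to -1, which is filtered out of the loss. A small NumPy illustration with assumed toy values:

import numpy as np

n_questions = 100
# qa = question id (1..n_questions) + answer * n_questions; 0 is padding
qa_batch = np.array([[3, 103, 7, 0],        # q3 wrong, q3 right, q7 wrong, padding
                     [150, 50, 0, 0]])      # q50 right, q50 wrong, padding, padding
target = (qa_batch.astype(int) - 1) // n_questions
print(target)   # [[ 0  1  0 -1] [ 1  0 -1 -1]]; the -1 entries are ignored in the loss
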
Example #5
    def sym_gen(self):
        ### TODO input variable 'q_data'
        q_data = mx.sym.Variable('q_data', shape=(self.seqlen, self.batch_size)) # (seqlen, batch_size)
        ### TODO input variable 'qa_data'
        qa_data = mx.sym.Variable('qa_data', shape=(self.seqlen, self.batch_size))  # (seqlen, batch_size)
        ### TODO input variable 'target'
        target = mx.sym.Variable('target', shape=(self.seqlen, self.batch_size)) #(seqlen, batch_size)

        ### Initialize Memory
        init_memory_key = mx.sym.Variable('init_memory_key_weight')
        init_memory_value = mx.sym.Variable('init_memory_value',
                                            shape=(self.memory_size, self.memory_value_state_dim),
                                            init=mx.init.Normal(0.1)) # (self.memory_size, self.memory_value_state_dim)
        init_memory_value = mx.sym.broadcast_to(mx.sym.expand_dims(init_memory_value, axis=0),
                                                shape=(self.batch_size, self.memory_size, self.memory_value_state_dim))

        mem = DKVMN(memory_size=self.memory_size,
                   memory_key_state_dim=self.memory_key_state_dim,
                   memory_value_state_dim=self.memory_value_state_dim,
                   init_memory_key=init_memory_key,
                   init_memory_value=init_memory_value,
                   name="DKVMN")


        ### embedding
        q_data = mx.sym.BlockGrad(q_data)
        q_embed_data = mx.sym.Embedding(data=q_data, input_dim=self.n_question+1,
                                        output_dim=self.q_embed_dim, name='q_embed')
        slice_q_embed_data = mx.sym.SliceChannel(q_embed_data, num_outputs=self.seqlen, axis=0, squeeze_axis=True)

        qa_data = mx.sym.BlockGrad(qa_data)
        qa_embed_data = mx.sym.Embedding(data=qa_data, input_dim=self.n_question*2+1,
                                         output_dim=self.qa_embed_dim, name='qa_embed')
        slice_qa_embed_data = mx.sym.SliceChannel(qa_embed_data, num_outputs=self.seqlen, axis=0, squeeze_axis=True)

        value_read_content_l = []
        input_embed_l = []
        for i in range(self.seqlen):
            ## Attention
            q = slice_q_embed_data[i]
            correlation_weight = mem.attention(q)

            ## Read Process
            read_content = mem.read(correlation_weight) #Shape (batch_size, memory_state_dim)
            ### save intermediate data
            value_read_content_l.append(read_content)
            input_embed_l.append(q)

            ## Write Process
            qa = slice_qa_embed_data[i]
            new_memory_value = mem.write(correlation_weight, qa)

        all_read_value_content = mx.sym.Concat(*value_read_content_l, num_args=self.seqlen, dim=0)

        input_embed_content = mx.sym.Concat(*input_embed_l, num_args=self.seqlen, dim=0)
        input_embed_content = mx.sym.FullyConnected(data=input_embed_content, num_hidden=50, name="input_embed_content")
        input_embed_content = mx.sym.Activation(data=input_embed_content, act_type='tanh', name="input_embed_content_tanh")

        read_content_embed = mx.sym.FullyConnected(data=mx.sym.Concat(all_read_value_content, input_embed_content, num_args=2, dim=1),
                                                   num_hidden=self.final_fc_dim, name="read_content_embed")
        read_content_embed = mx.sym.Activation(data=read_content_embed, act_type='tanh', name="read_content_embed_tanh")

        pred = mx.sym.FullyConnected(data=read_content_embed, num_hidden=1, name="final_fc")

        pred_prob = logistic_regression_mask_output(data=mx.sym.Reshape(pred, shape=(-1, )),
                                                    label=mx.sym.Reshape(data=target, shape=(-1,)),
                                                    ignore_label=-1., name='final_pred')
        return mx.sym.Group([pred_prob, pred, read_content_embed,correlation_weight])
Example #6
File: model.py  Project: XiangrongXu/ML
class MODEL(nn.Module):
    def __init__(self,
                 n_question,
                 batch_size,
                 q_embed_dim,
                 qa_embed_dim,
                 memory_size,
                 memory_key_state_dim,
                 memory_value_state_dim,
                 final_fc_dim,
                 student_num=None):
        super(MODEL, self).__init__()
        self.n_question = n_question
        self.batch_size = batch_size
        self.q_embed_dim = q_embed_dim
        self.qa_embed_dim = qa_embed_dim
        self.memory_size = memory_size
        self.memory_key_state_dim = memory_key_state_dim
        self.memory_value_state_dim = memory_value_state_dim
        self.final_fc_dim = final_fc_dim
        self.student_num = student_num

        self.input_embed_linear = nn.Linear(self.q_embed_dim,
                                            self.final_fc_dim,
                                            bias=True)
        self.read_embed_linear = nn.Linear(self.memory_value_state_dim +
                                           self.final_fc_dim,
                                           self.final_fc_dim,
                                           bias=True)
        self.predict_linear = nn.Linear(self.final_fc_dim, 1, bias=True)
        self.init_memory_key = nn.Parameter(
            torch.randn(self.memory_size, self.memory_key_state_dim))
        nn.init.kaiming_normal_(self.init_memory_key)
        self.init_memory_value = nn.Parameter(
            torch.randn(self.memory_size, self.memory_value_state_dim))
        nn.init.kaiming_normal_(self.init_memory_value)

        self.mem = DKVMN(self.memory_size, self.memory_key_state_dim,
                         self.memory_value_state_dim, self.init_memory_key)
        memory_value = nn.Parameter(
            torch.cat([
                self.init_memory_value.unsqueeze(0) for _ in range(batch_size)
            ], 0).data)
        self.mem.set_memory_value(memory_value)

        self.q_embed = nn.Embedding(self.n_question + 1,
                                    self.q_embed_dim,
                                    padding_idx=0)
        self.qa_embed = nn.Embedding(self.n_question * 2 + 1,
                                     self.qa_embed_dim,
                                     padding_idx=0)

    def init_params(self):
        nn.init.kaiming_normal_(self.predict_linear.weight)
        nn.init.kaiming_normal_(self.read_embed_linear.weight)
        nn.init.constant_(self.read_embed_linear.bias, 0)
        nn.init.constant_(self.predict_linear.bias, 0)

    def init_embeddings(self):
        nn.init.kaiming_normal_(self.q_embed.weight)
        nn.init.kaiming_normal_(self.qa_embed.weight)

    def forward(self, q_data, qa_data, target, student_id=None):
        batch_size = q_data.shape[0]
        seqlen = q_data.shape[1]
        q_embed_data = self.q_embed(q_data)
        qa_embed_data = self.qa_embed(qa_data)

        memory_value = nn.Parameter(
            torch.cat([
                self.init_memory_value.unsqueeze(0) for _ in range(batch_size)
            ], 0).data)
        self.mem.set_memory_value(memory_value)

        slice_q_data = torch.chunk(q_data, seqlen, 1)
        slice_q_embed_data = torch.chunk(q_embed_data, seqlen, 1)
        slice_qa_embed_data = torch.chunk(qa_embed_data, seqlen, 1)

        value_read_content_1 = []
        input_embed_1 = []
        predict_logs = []
        for i in range(seqlen):
            # attention
            q = slice_q_embed_data[i].squeeze(1)
            correlation_weight = self.mem.attention(q)
            if_memory_write = slice_q_data[i].squeeze(1).ge(1)
            if_memory_write = utils.variable(
                torch.FloatTensor(if_memory_write.data.tolist()), 1)

            # read
            read_content = self.mem.read(correlation_weight)
            value_read_content_1.append(read_content)
            input_embed_1.append(q)

            # write
            qa = slice_qa_embed_data[i].squeeze(1)
            new_memory_value = self.mem.write(correlation_weight, qa)

        all_read_value_content = torch.cat(
            [value_read_content_1[i].unsqueeze(1) for i in range(seqlen)], 1)
        input_embed_content = torch.cat(
            [input_embed_1[i].unsqueeze(1) for i in range(seqlen)], 1)

        predict_input = torch.cat(
            [all_read_value_content, input_embed_content], 2)
        read_content_embed = torch.tanh(
            self.read_embed_linear(
                predict_input.reshape(batch_size * seqlen, -1)))

        pred = self.predict_linear(read_content_embed)

        target_1d = target
        mask = target_1d.ge(0)

        pred_1d = pred.reshape(-1, 1)

        filtered_pred = torch.masked_select(pred_1d, mask)
        filtered_target = torch.masked_select(target_1d, mask)
        loss = nn.functional.binary_cross_entropy_with_logits(
            filtered_pred, filtered_target)

        return loss, torch.sigmoid(filtered_pred), filtered_target
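
Both PyTorch variants on this page rely on the same masking convention for the loss: padded time steps carry a target below 0, so a boolean mask built with target.ge(0) keeps only real interactions before the binary cross-entropy is computed. A minimal, self-contained sketch of that pattern (the tensors below are made-up illustrations, not values from the snippets above):

import torch
import torch.nn.functional as F

# Flattened logits and targets for an assumed batch_size=2, seq_len=3.
# A target of -1 marks a padded time step that must not contribute to the loss.
pred_1d = torch.tensor([[0.2], [1.5], [-0.7], [0.9], [0.0], [2.1]])
target_1d = torch.tensor([[1.], [0.], [-1.], [1.], [-1.], [0.]])

mask = target_1d.ge(0)                               # True only for real interactions
filtered_pred = torch.masked_select(pred_1d, mask)
filtered_target = torch.masked_select(target_1d, mask)

loss = F.binary_cross_entropy_with_logits(filtered_pred, filtered_target)
prob = torch.sigmoid(filtered_pred)                  # probabilities used for AUC/accuracy
print(loss.item(), prob)
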
예제 #7
0
class Model():
    def __init__(self, args, sess, name='KT'):
        self.args = args
        self.name = name
        self.sess = sess

        # Initialize Memory
        with tf.variable_scope('Memory'):
            init_memory_key = tf.get_variable('key', [self.args.memory_size, self.args.memory_key_state_dim], \
                initializer=tf.truncated_normal_initializer(stddev=0.1))
            init_memory_value = tf.get_variable('value', [self.args.memory_size,self.args.memory_value_state_dim], \
                initializer=tf.truncated_normal_initializer(stddev=0.1))

        # Broadcast memory value tensor to match [batch size, memory size, memory state dim]
        # First expand dims at axis 0 to create the 'batch size' axis, then tile along that axis
        # tf.tile(inputs, multiples) : the length of multiples must be the same as the number of dimensions in input
        # tf.stack takes a list and converts each element to a tensor
        init_memory_value = tf.tile(tf.expand_dims(init_memory_value, 0), tf.stack([self.args.batch_size, 1, 1]))
        print(init_memory_value.get_shape())
                
        self.memory = DKVMN(self.args.memory_size, self.args.memory_key_state_dim, \
                self.args.memory_value_state_dim, init_memory_key=init_memory_key, init_memory_value=init_memory_value, name='DKVMN')

        # Embedding to [batch size, seq_len, memory_state_dim(d_k or d_v)]
        with tf.variable_scope('Embedding'):
            # A
            self.q_embed_mtx = tf.get_variable('q_embed', [self.args.n_questions+1, self.args.memory_key_state_dim],\
                initializer=tf.truncated_normal_initializer(stddev=0.1))
            # B
            self.qa_embed_mtx = tf.get_variable('qa_embed', [2*self.args.n_questions+1, self.args.memory_value_state_dim], initializer=tf.truncated_normal_initializer(stddev=0.1))        

        self.prediction = self.build_network(reuse_flag=False)
        self.build_optimizer()
        #self.create_model()
    
    def build_network(self, reuse_flag):
        print('Building network')

        self.q_data = tf.placeholder(tf.int32, [self.args.batch_size], name='q_data') 
        self.qa_data = tf.placeholder(tf.int32, [self.args.batch_size], name='qa_data')
        self.target = tf.placeholder(tf.float32, [self.args.batch_size], name='target')


        # Embedding to [batch size, seq_len, memory key state dim]
        q_embed_data = tf.nn.embedding_lookup(self.q_embed_mtx, self.q_data)
        # List of [batch size, 1, memory key state dim] with 'seq_len' elements
        #print('Q_embedding shape : %s' % q_embed_data.get_shape())
        #slice_q_embed_data = tf.split(q_embed_data, self.args.seq_len, 1)
        #print(len(slice_q_embed_data), type(slice_q_embed_data), slice_q_embed_data[0].get_shape())
        # Embedding to [batch size, seq_len, memory value state dim]
        qa_embed_data = tf.nn.embedding_lookup(self.qa_embed_mtx, self.qa_data)
        #print('QA_embedding shape: %s' % qa_embed_data.get_shape())
        # List of [batch size, 1, memory value state dim] with 'seq_len' elements
        #slice_qa_embed_data = tf.split(qa_embed_data, self.args.seq_len, 1)
        
        #prediction = list()
        #reuse_flag = False

        # k_t : [batch size, memory key state dim]
        #q = tf.squeeze(slice_q_embed_data[i], 1)
        # Attention, [batch size, memory size]
        self.correlation_weight = self.memory.attention(q_embed_data)
        
        # Read process, [batch size, memory value state dim]
        self.read_content = self.memory.read(self.correlation_weight)
        
        # Write process, [batch size, memory size, memory value state dim]
        # qa : [batch size, memory value state dim]
        #qa = tf.squeeze(slice_qa_embed_data[i], 1)
        # Only last time step value is necessary

        self.new_memory_value = self.memory.write(self.correlation_weight, qa_embed_data, reuse=reuse_flag)

        mastery_level_prior_difficulty = tf.concat([self.read_content, q_embed_data], 1)
        # f_t
        summary_vector = tf.tanh(operations.linear(mastery_level_prior_difficulty, self.args.final_fc_dim, name='Summary_Vector', reuse=reuse_flag))
        # p_t
        pred_logits = operations.linear(summary_vector, 1, name='Prediction', reuse=reuse_flag)

        return pred_logits
        
    def build_optimizer(self):
        print('Building optimizer')
        # 'seq_len' means question sequences
        self.q_data_seq = tf.placeholder(tf.int32, [self.args.batch_size, self.args.seq_len], name='q_data_seq') 
        self.qa_data_seq = tf.placeholder(tf.int32, [self.args.batch_size, self.args.seq_len], name='qa_data_seq')
        self.target_seq = tf.placeholder(tf.float32, [self.args.batch_size, self.args.seq_len], name='target_seq')

        # Embedding to [batch size, seq_len, memory key state dim]
        #q_embed_data = tf.nn.embedding_lookup(q_embed_mtx, self.q_data_seq)
        # List of [batch size, 1, memory key state dim] with 'seq_len' elements
        #print('Q_embedding shape : %s' % q_embed_data.get_shape())
        slice_q_data = tf.split(self.q_data_seq, self.args.seq_len, 1)
        #print(len(slice_q_embed_data), type(slice_q_embed_data), slice_q_embed_data[0].get_shape())
        # Embedding to [batch size, seq_len, memory value state dim]
        #qa_embed_data = tf.nn.embedding_lookup(qa_embed_mtx, self.qa_data_seq)
        #print('QA_embedding shape: %s' % qa_embed_data.get_shape())
        # List of [batch size, 1, memory value state dim] with 'seq_len' elements
        slice_qa_data = tf.split(self.qa_data_seq, self.args.seq_len, 1)
        
        prediction = list()
        # build_network() above has already created the summary / prediction /
        # memory-write variables, so every time step here reuses them
        reuse_flag = True

        # Logics
        for i in range(self.args.seq_len):
            # Raw indices for this time step, [batch size]
            q = tf.squeeze(slice_q_data[i], 1)
            qa = tf.squeeze(slice_qa_data[i], 1)
            # Embedding to [batch size, memory key/value state dim]
            q_embed = tf.nn.embedding_lookup(self.q_embed_mtx, q)
            qa_embed = tf.nn.embedding_lookup(self.qa_embed_mtx, qa)

            # Attention, [batch size, memory size]
            self.correlation_weight = self.memory.attention(q_embed)
            # Read process, [batch size, memory value state dim]
            self.read_content = self.memory.read(self.correlation_weight)
            # Write process; only the last time step value is necessary
            self.new_memory_value = self.memory.write(self.correlation_weight, qa_embed, reuse=reuse_flag)

            mastery_level_prior_difficulty = tf.concat([self.read_content, q_embed], 1)
            # f_t
            summary_vector = tf.tanh(operations.linear(mastery_level_prior_difficulty, self.args.final_fc_dim, name='Summary_Vector', reuse=reuse_flag))
            # p_t
            pred_logits = operations.linear(summary_vector, 1, name='Prediction', reuse=reuse_flag)

            prediction.append(pred_logits)

        # 'prediction' : seq_len length list of [batch size ,1], make it [batch size, seq_len] tensor
        # tf.stack convert to [batch size, seq_len, 1]
        self.pred_logits = tf.reshape(tf.stack(prediction, axis=1), [self.args.batch_size, self.args.seq_len]) 

        # Define loss : standard cross entropy loss, need to ignore '-1' label example
        # Make target/label 1-d array
        target_1d = tf.reshape(self.target_seq, [-1])
        pred_logits_1d = tf.reshape(self.pred_logits, [-1])
        index = tf.where(tf.not_equal(target_1d, tf.constant(-1., dtype=tf.float32)))
        # tf.gather(params, indices) : Gather slices from params according to indices
        filtered_target = tf.gather(target_1d, index)
        filtered_logits = tf.gather(pred_logits_1d, index)
        self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=filtered_logits, labels=filtered_target))
        self.pred = tf.sigmoid(self.pred_logits)

        # Optimizer : SGD + MOMENTUM with learning rate decay
        self.global_step = tf.Variable(0, trainable=False)
        self.lr = tf.placeholder(tf.float32, [], name='learning_rate')
#        self.lr_decay = tf.train.exponential_decay(self.args.initial_lr, global_step=global_step, decay_steps=10000, decay_rate=0.667, staircase=True)
#        self.learning_rate = tf.maximum(lr, self.args.lr_lowerbound)
        optimizer = tf.train.MomentumOptimizer(self.lr, self.args.momentum)
        grads, vrbs = zip(*optimizer.compute_gradients(self.loss))
        grad, _ = tf.clip_by_global_norm(grads, self.args.maxgradnorm)
        self.train_op = optimizer.apply_gradients(zip(grad, vrbs), global_step=self.global_step)
#        grad_clip = [(tf.clip_by_value(grad, -self.args.maxgradnorm, self.args.maxgradnorm), var) for grad, var in grads]
        self.tr_vrbs = tf.trainable_variables()
        for i in self.tr_vrbs:
            print(i.name)

        self.saver = tf.train.Saver()

    def create_model(self):
        # 'seq_len' means question sequences
        self.q_data_seq = tf.placeholder(tf.int32, [self.args.batch_size, self.args.seq_len], name='q_data_seq') 
        self.qa_data_seq = tf.placeholder(tf.int32, [self.args.batch_size, self.args.seq_len], name='qa_data')
        self.target_seq = tf.placeholder(tf.float32, [self.args.batch_size, self.args.seq_len], name='target')
          
        '''
        # Initialize Memory
        with tf.variable_scope('Memory'):
            init_memory_key = tf.get_variable('key', [self.args.memory_size, self.args.memory_key_state_dim], \
                initializer=tf.truncated_normal_initializer(stddev=0.1))
            init_memory_value = tf.get_variable('value', [self.args.memory_size,self.args.memory_value_state_dim], \
                initializer=tf.truncated_normal_initializer(stddev=0.1))
        # Broadcast memory value tensor to match [batch size, memory size, memory state dim]
        # First expand dims at axis 0 to create the 'batch size' axis, then tile along that axis
        # tf.tile(inputs, multiples) : the length of multiples must be the same as the number of dimensions in input
        # tf.stack takes a list and converts each element to a tensor
        init_memory_value = tf.tile(tf.expand_dims(init_memory_value, 0), tf.stack([self.args.batch_size, 1, 1]))
        print(init_memory_value.get_shape())
                
        self.memory = DKVMN(self.args.memory_size, self.args.memory_key_state_dim, \
                self.args.memory_value_state_dim, init_memory_key=init_memory_key, init_memory_value=init_memory_value, name='DKVMN')

        # Embedding to [batch size, seq_len, memory_state_dim(d_k or d_v)]
        with tf.variable_scope('Embedding'):
            # A
            q_embed_mtx = tf.get_variable('q_embed', [self.args.n_questions+1, self.args.memory_key_state_dim],\
                initializer=tf.truncated_normal_initializer(stddev=0.1))
            # B
            qa_embed_mtx = tf.get_variable('qa_embed', [2*self.args.n_questions+1, self.args.memory_value_state_dim], initializer=tf.truncated_normal_initializer(stddev=0.1))        

        '''
        # Embedding to [batch size, seq_len, memory key state dim]
        q_embed_data = tf.nn.embedding_lookup(self.q_embed_mtx, self.q_data_seq)
        # List of [batch size, 1, memory key state dim] with 'seq_len' elements
        #print('Q_embedding shape : %s' % q_embed_data.get_shape())
        slice_q_embed_data = tf.split(q_embed_data, self.args.seq_len, 1)
        #print(len(slice_q_embed_data), type(slice_q_embed_data), slice_q_embed_data[0].get_shape())
        # Embedding to [batch size, seq_len, memory value state dim]
        qa_embed_data = tf.nn.embedding_lookup(self.qa_embed_mtx, self.qa_data_seq)
        #print('QA_embedding shape: %s' % qa_embed_data.get_shape())
        # List of [batch size, 1, memory value state dim] with 'seq_len' elements
        slice_qa_embed_data = tf.split(qa_embed_data, self.args.seq_len, 1)
        
        prediction = list()
        reuse_flag = False

        # Logics
        for i in range(self.args.seq_len):
            # To reuse linear vectors
            if i != 0:
                reuse_flag = True
            # k_t : [batch size, memory key state dim]
            q = tf.squeeze(slice_q_embed_data[i], 1)
            # Attention, [batch size, memory size]
            self.correlation_weight = self.memory.attention(q)
            
            # Read process, [batch size, memory value state dim]
            self.read_content = self.memory.read(self.correlation_weight)
            
            # Write process, [batch size, memory size, memory value state dim]
            # qa : [batch size, memory value state dim]
            qa = tf.squeeze(slice_qa_embed_data[i], 1)
            # Only last time step value is necessary
            self.new_memory_value = self.memory.write(self.correlation_weight, qa, reuse=reuse_flag)

            mastery_level_prior_difficulty = tf.concat([self.read_content, q], 1)
            # f_t
            summary_vector = tf.tanh(operations.linear(mastery_level_prior_difficulty, self.args.final_fc_dim, name='Summary_Vector', reuse=reuse_flag))
            # p_t
            pred_logits = operations.linear(summary_vector, 1, name='Prediction', reuse=reuse_flag)

            prediction.append(pred_logits)

        # 'prediction' : seq_len length list of [batch size ,1], make it [batch size, seq_len] tensor
        # tf.stack convert to [batch size, seq_len, 1]
        self.pred_logits = tf.reshape(tf.stack(prediction, axis=1), [self.args.batch_size, self.args.seq_len]) 

        # Define loss : standard cross entropy loss, need to ignore '-1' label example
        # Make target/label 1-d array
        target_1d = tf.reshape(self.target_seq, [-1])
        pred_logits_1d = tf.reshape(self.pred_logits, [-1])
        index = tf.where(tf.not_equal(target_1d, tf.constant(-1., dtype=tf.float32)))
        # tf.gather(params, indices) : Gather slices from params according to indices
        filtered_target = tf.gather(target_1d, index)
        filtered_logits = tf.gather(pred_logits_1d, index)
        self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=filtered_logits, labels=filtered_target))
        self.pred = tf.sigmoid(self.pred_logits)

        # Optimizer : SGD + MOMENTUM with learning rate decay
        self.global_step = tf.Variable(0, trainable=False)
        self.lr = tf.placeholder(tf.float32, [], name='learning_rate')
#        self.lr_decay = tf.train.exponential_decay(self.args.initial_lr, global_step=global_step, decay_steps=10000, decay_rate=0.667, staircase=True)
#        self.learning_rate = tf.maximum(lr, self.args.lr_lowerbound)
        optimizer = tf.train.MomentumOptimizer(self.lr, self.args.momentum)
        grads, vrbs = zip(*optimizer.compute_gradients(self.loss))
        grad, _ = tf.clip_by_global_norm(grads, self.args.maxgradnorm)
        self.train_op = optimizer.apply_gradients(zip(grad, vrbs), global_step=self.global_step)
#        grad_clip = [(tf.clip_by_value(grad, -self.args.maxgradnorm, self.args.maxgradnorm), var) for grad, var in grads]
        self.tr_vrbs = tf.trainable_variables()
        for i in self.tr_vrbs:
            print(i.name)

        self.saver = tf.train.Saver()


    def train(self, train_q_data, train_qa_data, valid_q_data, valid_qa_data):
        # q_data, qa_data : [samples, seq_len]
        shuffle_index = np.random.permutation(train_q_data.shape[0])
        q_data_shuffled = train_q_data[shuffle_index]
        qa_data_shuffled = train_qa_data[shuffle_index]

        training_step = train_q_data.shape[0] // self.args.batch_size
        self.sess.run(tf.global_variables_initializer())
        
        if self.args.show:
            from utils import ProgressBar
            bar = ProgressBar('Training', max=training_step)

        self.train_count = 0
        if self.args.init_from:
            if self.load():
                print('Checkpoint_loaded')
            else:
                print('No checkpoint')
        else:
            if os.path.exists(os.path.join(self.args.checkpoint_dir, self.model_dir)):
                try:
                    shutil.rmtree(os.path.join(self.args.checkpoint_dir, self.model_dir))
                    os.remove(os.path.join(self.args.log_dir, self.model_dir+'.csv'))
                except (FileNotFoundError, IOError) as e:
                    print('[Delete Error] %s - %s' % (e.filename, e.strerror))
        
        best_valid_auc = 0

        # Training
        for epoch in range(0, self.args.num_epochs):
            if self.args.show:
                bar.next()

            pred_list = list()
            target_list = list()        
            epoch_loss = 0
            learning_rate = tf.train.exponential_decay(self.args.initial_lr, global_step=self.global_step, decay_steps=self.args.anneal_interval*training_step, decay_rate=0.667, staircase=True)

            #print('Epoch %d starts with learning rate : %3.5f' % (epoch+1, self.sess.run(learning_rate)))
            for steps in range(training_step):
                # [batch size, seq_len]
                q_batch_seq = q_data_shuffled[steps*self.args.batch_size:(steps+1)*self.args.batch_size, :]
                qa_batch_seq = qa_data_shuffled[steps*self.args.batch_size:(steps+1)*self.args.batch_size, :]
    
                # qa : exercise index + answer(0 or 1) * number of exercises
                # right : 1, wrong : 0, padding : -1
                target = qa_batch_seq[:,:]
                # Make integer type to calculate target
                target = target.astype(np.int)
                target_batch = (target - 1) // self.args.n_questions  
                target_batch = target_batch.astype(np.float)

                feed_dict = {self.q_data_seq:q_batch_seq, self.qa_data_seq:qa_batch_seq, self.target_seq:target_batch, self.lr:self.args.initial_lr}
                #self.lr:self.sess.run(learning_rate)
                loss_, pred_, _, = self.sess.run([self.loss, self.pred, self.train_op], feed_dict=feed_dict)
                # Get right answer index
                # Make [batch size * seq_len, 1]
                right_target = np.asarray(target_batch).reshape(-1,1)
                right_pred = np.asarray(pred_).reshape(-1,1)
                # np.flatnonzero returns indices which is nonzero, convert it list 
                right_index = np.flatnonzero(right_target != -1.).tolist()
                # Number of 'training_step' elements list with [batch size * seq_len, ]
                pred_list.append(right_pred[right_index])
                target_list.append(right_target[right_index])

                epoch_loss += loss_
                #print('Epoch %d/%d, steps %d/%d, loss : %3.5f' % (epoch+1, self.args.num_epochs, steps+1, training_step, loss_))
                

            if self.args.show:
                bar.finish()        
            
            all_pred = np.concatenate(pred_list, axis=0)
            all_target = np.concatenate(target_list, axis=0)

            # Compute metrics
            self.auc = metrics.roc_auc_score(all_target, all_pred)
            # Extract elements with boolean index
            # Make '1' for elements higher than 0.5
            # Make '0' for elements lower than 0.5
            all_pred[all_pred > 0.5] = 1
            all_pred[all_pred <= 0.5] = 0
            self.accuracy = metrics.accuracy_score(all_target, all_pred)

            epoch_loss = epoch_loss / training_step    
            print('Epoch %d/%d, loss : %3.5f, auc : %3.5f, accuracy : %3.5f' % (epoch+1, self.args.num_epochs, epoch_loss, self.auc, self.accuracy))
            self.write_log(epoch=epoch+1, auc=self.auc, accuracy=self.accuracy, loss=epoch_loss, name='training_')

            valid_steps = valid_q_data.shape[0] // self.args.batch_size
            valid_pred_list = list()
            valid_target_list = list()
            for s in range(valid_steps):
                # Validation
                valid_q = valid_q_data[s*self.args.batch_size:(s+1)*self.args.batch_size, :]
                valid_qa = valid_qa_data[s*self.args.batch_size:(s+1)*self.args.batch_size, :]
                # right : 1, wrong : 0, padding : -1
                valid_target = (valid_qa - 1) // self.args.n_questions
                valid_feed_dict = {self.q_data_seq : valid_q, self.qa_data_seq : valid_qa, self.target_seq : valid_target}
                valid_loss, valid_pred = self.sess.run([self.loss, self.pred], feed_dict=valid_feed_dict)
                # Same with training set
                valid_right_target = np.asarray(valid_target).reshape(-1,)
                valid_right_pred = np.asarray(valid_pred).reshape(-1,)
                valid_right_index = np.flatnonzero(valid_right_target != -1).tolist()    
                valid_target_list.append(valid_right_target[valid_right_index])
                valid_pred_list.append(valid_right_pred[valid_right_index])
            
            all_valid_pred = np.concatenate(valid_pred_list, axis=0)
            all_valid_target = np.concatenate(valid_target_list, axis=0)

            valid_auc = metrics.roc_auc_score(all_valid_target, all_valid_pred)
            # For validation accuracy
            all_valid_pred[all_valid_pred > 0.5] = 1
            all_valid_pred[all_valid_pred <= 0.5] = 0
            valid_accuracy = metrics.accuracy_score(all_valid_target, all_valid_pred)
            print('Epoch %d/%d, valid auc : %3.5f, valid accuracy : %3.5f' %(epoch+1, self.args.num_epochs, valid_auc, valid_accuracy))
            # Valid log
            self.write_log(epoch=epoch+1, auc=valid_auc, accuracy=valid_accuracy, loss=valid_loss, name='valid_')
            if valid_auc > best_valid_auc:
                print('%3.4f to %3.4f' % (best_valid_auc, valid_auc))
                best_valid_auc = valid_auc
                best_epoch = epoch + 1
                self.save(best_epoch)

        return best_epoch    
            
    def test(self, test_q, test_qa):
        steps = test_q.shape[0] // self.args.batch_size
        self.sess.run(tf.global_variables_initializer())
        if self.load():
            print('CKPT Loaded')
        else:
            raise Exception('CKPT need')

        pred_list = list()
        target_list = list()

        for s in range(steps):
            test_q_batch = test_q[s*self.args.batch_size:(s+1)*self.args.batch_size, :]
            test_qa_batch = test_qa[s*self.args.batch_size:(s+1)*self.args.batch_size, :]
            target = test_qa_batch[:,:]
            target = target.astype(np.int)
            target_batch = (target - 1) // self.args.n_questions  
            target_batch = target_batch.astype(np.float)
            feed_dict = {self.q_data_seq:test_q_batch, self.qa_data_seq:test_qa_batch, self.target_seq:target_batch}
            loss_, pred_ = self.sess.run([self.loss, self.pred], feed_dict=feed_dict)
            # Get right answer index
            # Make [batch size * seq_len, 1]
            right_target = np.asarray(target_batch).reshape(-1,1)
            right_pred = np.asarray(pred_).reshape(-1,1)
            # np.flatnonzero returns indices which is nonzero, convert it list 
            right_index = np.flatnonzero(right_target != -1.).tolist()
            # Number of 'training_step' elements list with [batch size * seq_len, ]
            pred_list.append(right_pred[right_index])
            target_list.append(right_target[right_index])

        all_pred = np.concatenate(pred_list, axis=0)
        all_target = np.concatenate(target_list, axis=0)

        # Compute metrics
        # AUC is computed on the raw probabilities, before thresholding
        self.test_auc = metrics.roc_auc_score(all_target, all_pred)
        # Make '1' for elements higher than 0.5
        # Make '0' for elements lower than 0.5
        all_pred[all_pred > 0.5] = 1
        all_pred[all_pred <= 0.5] = 0

        self.test_accuracy = metrics.accuracy_score(all_target, all_pred)

        print('Test auc : %3.4f, Test accuracy : %3.4f' % (self.test_auc, self.test_accuracy))


    @property
    def model_dir(self):
        return '{}_{}batch_{}epochs'.format(self.args.dataset, self.args.batch_size, self.args.num_epochs)

    def load(self):
        checkpoint_dir = os.path.join(self.args.checkpoint_dir, self.model_dir)
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.train_count = int(ckpt_name.split('-')[-1])
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            return True
        else:
            return False

    def save(self, global_step):
        model_name = 'DKVMN'
        checkpoint_dir = os.path.join(self.args.checkpoint_dir, self.model_dir)
        if not os.path.exists(checkpoint_dir):
            os.mkdir(checkpoint_dir)
        self.saver.save(self.sess, os.path.join(checkpoint_dir, model_name), global_step=global_step)
        print('Save checkpoint at %d' % (global_step+1))

    # Log file
    def write_log(self, auc, accuracy, loss, epoch, name='training_'):
        log_path = os.path.join(self.args.log_dir, name+self.model_dir+'.csv')
        if not os.path.exists(log_path):
            self.log_file = open(log_path, 'w')
            self.log_file.write('Epoch\tAuc\tAccuracy\tloss\n')
        else:
            self.log_file = open(log_path, 'a')    
        
        self.log_file.write(str(epoch) + '\t' + str(auc) + '\t' + str(accuracy) + '\t' + str(loss) + '\n')
        self.log_file.flush()    
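
The train() and test() loops above recover the correctness label from the combined question-answer index with (target - 1) // n_questions. Under the encoding these snippets assume (qa = q + a * n_questions, with q in 1..n_questions, a in {0, 1}, and 0 reserved for padding), that integer division yields 1 for a correct answer, 0 for a wrong one, and -1 for padding. A small NumPy sketch with invented values:

import numpy as np

n_questions = 100
q = np.array([[5, 17, 0],        # 0 marks a padded step
              [42, 5, 99]])
a = np.array([[1, 0, 0],
              [0, 1, 1]])        # 1 = correct, 0 = wrong

# Combined index: qa = q + a * n_questions, padding stays 0
qa = np.where(q > 0, q + a * n_questions, 0)

# Label recovery as in train()/test(): right -> 1, wrong -> 0, padding -> -1
target = (qa.astype(np.int64) - 1) // n_questions
print(target)
# [[ 1  0 -1]
#  [ 0  1  1]]
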
예제 #8
0
    def sym_gen(self):
        ### TODO input variable 'q_data'
        q_data = mx.sym.Variable('q_data', shape=(self.seqlen, self.batch_size))  # (seqlen, batch_size)
        ### TODO input variable 'qa_data'
        qa_data = mx.sym.Variable('qa_data', shape=(self.seqlen, self.batch_size))  # (seqlen, batch_size)
        ### TODO input variable 'target'
        target = mx.sym.Variable('target', shape=(self.seqlen, self.batch_size))  # (seqlen, batch_size)

        ### Initialize Memory
        init_memory_key = mx.sym.Variable('init_memory_key_weight')
        init_memory_value = mx.sym.Variable('init_memory_value',
                                            shape=(self.memory_size, self.memory_value_state_dim),
                                            init=mx.init.Normal(0.1))  # (self.memory_size, self.memory_value_state_dim)
        init_memory_value = mx.sym.broadcast_to(mx.sym.expand_dims(init_memory_value, axis=0),
                                                shape=(self.batch_size, self.memory_size, self.memory_value_state_dim))

        mem = DKVMN(memory_size=self.memory_size,
                    memory_key_state_dim=self.memory_key_state_dim,
                    memory_value_state_dim=self.memory_value_state_dim,
                    init_memory_key=init_memory_key,
                    init_memory_value=init_memory_value,
                    name="DKVMN")

        ### embedding
        q_data = mx.sym.BlockGrad(q_data)
        q_embed_data = mx.sym.Embedding(data=q_data, input_dim=self.n_question + 1,
                                        output_dim=self.q_embed_dim, name='q_embed')
        slice_q_embed_data = mx.sym.SliceChannel(q_embed_data, num_outputs=self.seqlen, axis=0, squeeze_axis=True)
        qa_data = mx.sym.BlockGrad(qa_data)
        qa_embed_data = mx.sym.Embedding(data=qa_data, input_dim=self.n_question * 2 + 1,
                                         output_dim=self.qa_embed_dim, name='qa_embed')
        slice_qa_embed_data = mx.sym.SliceChannel(qa_embed_data, num_outputs=self.seqlen, axis=0, squeeze_axis=True)

        # memory block
        value_read_content_l = []  # (seqlen, batch_size, memory_state_dim)
        input_embed_l = []  # (seqlen, batch_size, q_embed_dim)
        for i in range(self.seqlen):
            ## Attention
            q = slice_q_embed_data[i]
            correlation_weight = mem.attention(q)
            ## Read Process
            read_content = mem.read(correlation_weight)  # Shape (batch_size, memory_state_dim)
            ### save intermediate data
            value_read_content_l.append(read_content)
            input_embed_l.append(q)
            ## Write Process
            qa = slice_qa_embed_data[i]
            new_memory_value = mem.write(correlation_weight, qa)

        # (batch_size * seqlen, memory_state_dim)
        all_read_value_content = mx.sym.Concat(*value_read_content_l, num_args=self.seqlen, dim=0)
        all_read_value_content = mx.sym.reshape(data=all_read_value_content, shape=(self.seqlen, self.batch_size, -1))
        # (batch_size * seqlen, q_embed_dim)
        input_embed_content = mx.sym.Concat(*input_embed_l, num_args=self.seqlen, dim=0)
        input_embed_content = mx.sym.reshape(data=input_embed_content, shape=(self.seqlen, self.batch_size, -1))

        combined_data = mx.sym.Concat(all_read_value_content, input_embed_content, num_args=2, dim=2)

        lstm = mx.gluon.rnn.LSTM(50, num_layers=1, layout='TNC', prefix='LSTM')
        # (seqlen, batch_size, final_fc_dim)
        output = lstm(combined_data)
        output = mx.sym.reshape(data=output, shape=(-3, 0))

        # output = mx.sym.FullyConnected(data=output, num_hidden=self.final_fc_dim, name='fc')
        # output = mx.sym.Activation(data=output, act_type='tanh', name='fc_tanh')

        pred = mx.sym.FullyConnected(data=output, num_hidden=1, name="final_fc")
        pred_prob = logistic_regression_mask_output(data=mx.sym.Reshape(pred, shape=(-1,)),
                                                    label=mx.sym.Reshape(data=target, shape=(-1,)),
                                                    ignore_label=-1., name='final_pred')

        net = mx.sym.Group([pred_prob])
        return net
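
Every example on this page delegates the per-step memory operations to a DKVMN class with an attention / read / write interface whose implementation is not shown here. As a rough sketch of what that interface typically looks like in the standard DKVMN formulation (correlation weights from a softmax over key similarities, reads as a weighted sum over the value memory, writes through learned erase and add vectors), with the class name, shapes, and method names below being illustrative assumptions rather than code from any of the repositories above:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DKVMNSketch(nn.Module):
    """Minimal key-value memory exposing the attention/read/write interface used above."""
    def __init__(self, memory_size, memory_key_state_dim, memory_value_state_dim, init_memory_key):
        super().__init__()
        self.memory_key = init_memory_key        # (memory_size, key_dim), fixed during the sequence
        self.memory_value = None                 # (batch, memory_size, value_dim), set per batch
        self.erase_linear = nn.Linear(memory_value_state_dim, memory_value_state_dim)
        self.add_linear = nn.Linear(memory_value_state_dim, memory_value_state_dim)

    def init_value_memory(self, memory_value):
        self.memory_value = memory_value

    def attention(self, q_embed):
        # w_t = softmax(k_t . M^k) over memory slots, shape (batch, memory_size)
        return F.softmax(torch.matmul(q_embed, self.memory_key.t()), dim=1)

    def read(self, correlation_weight):
        # r_t = sum_i w_t(i) * M^v(i), shape (batch, value_dim)
        return torch.bmm(correlation_weight.unsqueeze(1), self.memory_value).squeeze(1)

    def write(self, correlation_weight, qa_embed):
        # Erase/add update: M^v(i) <- M^v(i) * (1 - w_t(i) * e_t) + w_t(i) * a_t
        erase = torch.sigmoid(self.erase_linear(qa_embed)).unsqueeze(1)   # (batch, 1, value_dim)
        add = torch.tanh(self.add_linear(qa_embed)).unsqueeze(1)          # (batch, 1, value_dim)
        w = correlation_weight.unsqueeze(2)                               # (batch, memory_size, 1)
        self.memory_value = self.memory_value * (1 - w * erase) + w * add
        return self.memory_value
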
예제 #9
0
File: model.py Project: MIracleyin/DKVMN
class MODEL(nn.Module):
    def __init__(self,
                 n_question,
                 batch_size,
                 q_embed_dim,
                 qa_embed_dim,
                 memory_size,
                 memory_key_state_dim,
                 memory_value_state_dim,
                 final_fc_dim,
                 student_num=None):
        super(MODEL, self).__init__()
        self.n_question = n_question
        self.batch_size = batch_size
        self.q_embed_dim = q_embed_dim
        self.qa_embed_dim = qa_embed_dim
        self.memory_size = memory_size
        self.memory_key_state_dim = memory_key_state_dim
        self.memory_value_state_dim = memory_value_state_dim
        self.final_fc_dim = final_fc_dim
        self.student_num = student_num

        self.input_embed_linear = nn.Linear(self.q_embed_dim,
                                            self.final_fc_dim,
                                            bias=True)
        self.read_embed_linear = nn.Linear(self.memory_value_state_dim +
                                           self.final_fc_dim,
                                           self.final_fc_dim,
                                           bias=True)
        self.predict_linear = nn.Linear(self.final_fc_dim, 1, bias=True)
        self.init_memory_key = nn.Parameter(
            torch.randn(self.memory_size, self.memory_key_state_dim))  # random
        nn.init.kaiming_normal_(self.init_memory_key)
        self.init_memory_value = nn.Parameter(
            torch.randn(self.memory_size, self.memory_value_state_dim))
        nn.init.kaiming_normal_(self.init_memory_value)

        self.mem = DKVMN(
            memory_size=self.memory_size,
            memory_key_state_dim=self.memory_key_state_dim,  # the key memory stays fixed after initialization
            memory_value_state_dim=self.memory_value_state_dim,
            init_memory_key=self.init_memory_key)  # the value memory is continually updated by write

        memory_value = nn.Parameter(
            torch.cat([
                self.init_memory_value.unsqueeze(0) for _ in range(batch_size)
            ], 0).data)
        self.mem.init_value_memory(memory_value)

        self.q_embed = nn.Embedding(self.n_question + 1,
                                    self.q_embed_dim,
                                    padding_idx=0)
        self.qa_embed = nn.Embedding(2 * self.n_question + 1,
                                     self.qa_embed_dim,
                                     padding_idx=0)

    def init_params(self):
        nn.init.kaiming_normal_(self.predict_linear.weight)
        nn.init.kaiming_normal_(self.read_embed_linear.weight)
        nn.init.constant_(self.read_embed_linear.bias, 0)
        nn.init.constant_(self.predict_linear.bias, 0)
        # nn.init.constant_(self.input_embed_linear.bias, 0)
        # nn.init.normal(self.input_embed_linear.weight, std=0.02)

    def init_embeddings(self):
        nn.init.kaiming_normal_(self.q_embed.weight)
        nn.init.kaiming_normal_(self.qa_embed.weight)

    def forward(self, q_data, qa_data, target, student_id=None):
        batch_size = q_data.shape[0]
        seqlen = q_data.shape[1]
        q_embed_data = self.q_embed(q_data)
        qa_embed_data = self.qa_embed(qa_data)

        memory_value = nn.Parameter(
            torch.cat([
                self.init_memory_value.unsqueeze(0) for _ in range(batch_size)
            ], 0).data)  # memory: 32, 20, 200
        self.mem.init_value_memory(memory_value)

        slice_q_data = torch.chunk(q_data, seqlen, 1)
        slice_q_embed_data = torch.chunk(q_embed_data, seqlen, 1)
        slice_qa_embed_data = torch.chunk(qa_embed_data, seqlen, 1)

        value_read_content_l = []
        input_embed_l = []
        # new_memory_value_l = []
        predict_logs = []
        for i in range(seqlen):
            ## Attention
            q = slice_q_embed_data[i].squeeze(1)
            correlation_weight = self.mem.attention(q)
            if_memory_write = slice_q_data[i].squeeze(1).ge(1)  # True for non-padding steps (question index >= 1)
            if_memory_write = utils.varible(
                torch.FloatTensor(if_memory_write.data.tolist()), 1)

            ## Read Process
            read_content = self.mem.read(correlation_weight)
            value_read_content_l.append(read_content)
            input_embed_l.append(q)
            ## Write Process
            qa = slice_qa_embed_data[i].squeeze(1)
            new_memory_value = self.mem.write(
                correlation_weight, qa,
                if_memory_write)  # the memory_value after the last of the seqlen updates is used as the knowledge-state output
            # new_memory_value_l.append(new_memory_value.squ)
            # read_content_embed = torch.tanh(self.read_embed_linear(torch.cat([read_content, q], 1)))
            # pred = self.predict_linear(read_content_embed)
            # predict_logs.append(pred)

        all_read_value_content = torch.cat(
            [value_read_content_l[i].unsqueeze(1) for i in range(seqlen)],
            1)  # concatenate the read content of every exercise step
        input_embed_content = torch.cat(
            [input_embed_l[i].unsqueeze(1) for i in range(seqlen)], 1)
        # input_embed_content = input_embed_content.view(batch_size * seqlen, -1)
        # input_embed_content = torch.tanh(self.input_embed_linear(input_embed_content))
        # input_embed_content = input_embed_content.view(batch_size, seqlen, -1)

        predict_input = torch.cat(
            [all_read_value_content, input_embed_content], 2)
        read_content_embed = torch.tanh(
            self.read_embed_linear(predict_input.view(batch_size * seqlen,
                                                      -1)))

        pred = self.predict_linear(read_content_embed)
        # predicts = torch.cat([predict_logs[i] for i in range(seqlen)], 1)
        target_1d = target  # [batch_size * seq_len, 1]
        mask = target_1d.ge(0)  # [batch_size * seq_len, 1]
        # pred_1d = predicts.view(-1, 1)           # [batch_size * seq_len, 1]
        pred_1d = pred.view(-1, 1)  # [batch_size * seq_len, 1]

        filtered_pred = torch.masked_select(pred_1d, mask)
        filtered_target = torch.masked_select(target_1d, mask)
        # memory_value = torch.masked_select(new_memory_value.view(batch_size * seqlen, -1), mask)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(
            filtered_pred, filtered_target)

        return loss, torch.sigmoid(
            filtered_pred), filtered_target, new_memory_value
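
Unlike the earlier PyTorch variant, this one also passes if_memory_write to mem.write: a per-sample 0/1 gate marking whether the step is a real interaction (question index >= 1) or padding, built with the project-specific utils.varible helper (presumably a thin wrapper that places the tensor on the right device). How the gate is consumed lives inside the DKVMN class and is not shown on this page; one plausible reading, stated as an assumption rather than the repository's exact code, is that padded steps simply keep the previous memory:

import torch

def gated_memory_update(old_memory_value, new_memory_value, if_memory_write):
    # old/new_memory_value: (batch, memory_size, value_dim)
    # if_memory_write: (batch,) with 1.0 for real steps and 0.0 for padded steps
    gate = if_memory_write.view(-1, 1, 1)
    return gate * new_memory_value + (1.0 - gate) * old_memory_value
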
예제 #10
0
class DeepIRTModel(object):
    def __init__(self, args, sess, name="KT"):
        self.args = args
        self.sess = sess
        self.name = name
        self.create_model()

    def create_model(self):
        self._create_placeholder()
        self._influence()
        self._create_loss()
        self._create_optimizer()
        self._add_summary()

    def _create_placeholder(self):
        logger.info("Initializing Placeholder")
        self.q_data = tf.placeholder(tf.int32,
                                     [self.args.batch_size, self.args.seq_len],
                                     name='q_data')
        self.qa_data = tf.placeholder(
            tf.int32, [self.args.batch_size, self.args.seq_len],
            name='qa_data')
        self.label = tf.placeholder(tf.float32,
                                    [self.args.batch_size, self.args.seq_len],
                                    name='label')

    def _influence(self):
        # Initialize Memory
        logger.info("Initializing Key and Value Memory")
        with tf.variable_scope("Memory"):
            init_key_memory = tf.get_variable(
                'key_memory_matrix',
                [self.args.memory_size, self.args.key_memory_state_dim],
                initializer=tf.truncated_normal_initializer(stddev=0.1))
            init_value_memory = tf.get_variable(
                'value_memory_matrix',
                [self.args.memory_size, self.args.value_memory_state_dim],
                initializer=tf.truncated_normal_initializer(stddev=0.1))

        # Boardcast value-memory matrix to Shape (batch_size, memory_size, memory_value_state_dim)
        init_value_memory = tf.tile(  # tile the number of value-memory by the number of batch
            tf.expand_dims(init_value_memory, 0),  # make the batch-axis
            tf.stack([self.args.batch_size, 1, 1]))
        logger.debug("Shape of init_value_memory = {}".format(
            init_value_memory.get_shape()))
        logger.debug("Shape of init_key_memory = {}".format(
            init_key_memory.get_shape()))

        # Initialize DKVMN
        self.memory = DKVMN(
            memory_size=self.args.memory_size,
            key_memory_state_dim=self.args.key_memory_state_dim,
            value_memory_state_dim=self.args.value_memory_state_dim,
            init_key_memory=init_key_memory,
            init_value_memory=init_value_memory,
            name="DKVMN")

        # Initialize Embedding
        logger.info("Initializing Q and QA Embedding")
        with tf.variable_scope('Embedding'):
            q_embed_matrix = tf.get_variable(
                'q_embed',
                [self.args.n_questions + 1, self.args.key_memory_state_dim],
                initializer=tf.truncated_normal_initializer(stddev=0.1))
            qa_embed_matrix = tf.get_variable(
                'qa_embed', [
                    2 * self.args.n_questions + 1,
                    self.args.value_memory_state_dim
                ],
                initializer=tf.truncated_normal_initializer(stddev=0.1))

        # Embedding to Shape (batch size, seq_len, memory_state_dim(d_k or d_v))
        logger.info("Initializing Embedding Lookup")
        q_embed_data = tf.nn.embedding_lookup(q_embed_matrix, self.q_data)
        qa_embed_data = tf.nn.embedding_lookup(qa_embed_matrix, self.qa_data)

        logger.debug("Shape of q_embed_data: {}".format(
            q_embed_data.get_shape()))
        logger.debug("Shape of qa_embed_data: {}".format(
            qa_embed_data.get_shape()))

        sliced_q_embed_data = tf.split(value=q_embed_data,
                                       num_or_size_splits=self.args.seq_len,
                                       axis=1)
        sliced_qa_embed_data = tf.split(value=qa_embed_data,
                                        num_or_size_splits=self.args.seq_len,
                                        axis=1)
        logger.debug("Shape of sliced_q_embed_data[0]: {}".format(
            sliced_q_embed_data[0].get_shape()))
        logger.debug("Shape of sliced_qa_embed_data[0]: {}".format(
            sliced_qa_embed_data[0].get_shape()))

        pred_z_values = list()
        student_abilities = list()
        question_difficulties = list()
        reuse_flag = False
        logger.info("Initializing Influence Procedure")
        for i in range(self.args.seq_len):
            # To reuse linear vectors
            if i != 0:
                reuse_flag = True

            # Get the query and content vector
            q = tf.squeeze(sliced_q_embed_data[i], 1)
            qa = tf.squeeze(sliced_qa_embed_data[i], 1)
            logger.debug("qeury vector q: {}".format(q))
            logger.debug("content vector qa: {}".format(qa))

            # Attention, correlation_weight: Shape (batch_size, memory_size)
            self.correlation_weight = self.memory.attention(
                embedded_query_vector=q)
            logger.debug("correlation_weight: {}".format(
                self.correlation_weight))

            # Read process, read_content: (batch_size, value_memory_state_dim)
            self.read_content = self.memory.read(
                correlation_weight=self.correlation_weight)
            logger.debug("read_content: {}".format(self.read_content))

            # Write process, new_memory_value: Shape (batch_size, memory_size, value_memory_state_dim)
            self.new_memory_value = self.memory.write(self.correlation_weight,
                                                      qa,
                                                      reuse=reuse_flag)
            logger.debug("new_memory_value: {}".format(self.new_memory_value))

            # Build the feature vector -- summary_vector
            mastery_level_prior_difficulty = tf.concat([self.read_content, q],
                                                       1)

            self.summary_vector = layers.fully_connected(
                inputs=mastery_level_prior_difficulty,
                num_outputs=self.args.summary_vector_output_dim,
                scope='SummaryOperation',
                reuse=reuse_flag,
                activation_fn=tf.nn.tanh)
            logger.debug("summary_vector: {}".format(self.summary_vector))

            # Calculate the student ability level from summary vector
            student_ability = layers.fully_connected(
                inputs=self.summary_vector,
                num_outputs=1,
                scope='StudentAbilityOutputLayer',
                reuse=reuse_flag,
                activation_fn=None)

            # Calculate the question difficulty level from the question embedding
            question_difficulty = layers.fully_connected(
                inputs=q,
                num_outputs=1,
                scope='QuestionDifficultyOutputLayer',
                reuse=reuse_flag,
                activation_fn=tf.nn.tanh)

            # Prediction
            pred_z_value = 3.0 * student_ability - question_difficulty
            pred_z_values.append(pred_z_value)
            student_abilities.append(student_ability)
            question_difficulties.append(question_difficulty)

        self.pred_z_values = tf.reshape(tf.stack(
            pred_z_values, axis=1), [self.args.batch_size, self.args.seq_len])
        self.student_abilities = tf.reshape(
            tf.stack(student_abilities, axis=1),
            [self.args.batch_size, self.args.seq_len])
        self.question_difficulties = tf.reshape(
            tf.stack(question_difficulties, axis=1),
            [self.args.batch_size, self.args.seq_len])
        logger.debug("Shape of pred_z_values: {}".format(self.pred_z_values))
        logger.debug("Shape of student_abilities: {}".format(
            self.student_abilities))
        logger.debug("Shape of question_difficulties: {}".format(
            self.question_difficulties))

    def _create_loss(self):
        logger.info("Initializing Loss Function")

        # convert into 1D
        label_1d = tf.reshape(self.label, [-1])
        pred_z_values_1d = tf.reshape(self.pred_z_values, [-1])
        student_abilities_1d = tf.reshape(self.student_abilities, [-1])
        question_difficulties_1d = tf.reshape(self.question_difficulties, [-1])

        # find the label index that is not masking
        index = tf.where(
            tf.not_equal(label_1d, tf.constant(-1., dtype=tf.float32)))

        # masking
        filtered_label = tf.gather(label_1d, index)
        filtered_z_values = tf.gather(pred_z_values_1d, index)
        filtered_student_abilities = tf.gather(student_abilities_1d, index)
        filtered_question_difficulties = tf.gather(question_difficulties_1d,
                                                   index)
        logger.debug("Shape of filtered_label: {}".format(filtered_label))
        logger.debug(
            "Shape of filtered_z_values: {}".format(filtered_z_values))
        logger.debug("Shape of filtered_student_abilities: {}".format(
            filtered_student_abilities))
        logger.debug("Shape of filtered_question_difficulties: {}".format(
            filtered_question_difficulties))

        if self.args.use_ogive_model:
            # make prediction using normal ogive model
            dist = tfd.Normal(loc=0.0, scale=1.0)
            self.pred = dist.cdf(pred_z_values_1d)
            filtered_pred = dist.cdf(filtered_z_values)
        else:
            self.pred = tf.math.sigmoid(pred_z_values_1d)
            filtered_pred = tf.math.sigmoid(filtered_z_values)

        # convert the prediction probability to logit, i.e., log(p/(1-p))
        epsilon = 1e-6
        clipped_filtered_pred = tf.clip_by_value(filtered_pred, epsilon,
                                                 1. - epsilon)
        filtered_logits = tf.log(clipped_filtered_pred /
                                 (1 - clipped_filtered_pred))

        # cross entropy loss
        cross_entropy = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=filtered_logits,
                                                    labels=filtered_label))

        self.loss = cross_entropy

    def _create_optimizer(self):
        with tf.variable_scope('Optimizer'):
            self.optimizer = tf.train.AdamOptimizer(
                learning_rate=self.args.learning_rate)
            gvs = self.optimizer.compute_gradients(self.loss)
            clipped_gvs = [(tf.clip_by_norm(grad,
                                            self.args.max_grad_norm), var)
                           for grad, var in gvs]
            self.train_op = self.optimizer.apply_gradients(clipped_gvs)

    def _add_summary(self):
        tf.summary.scalar('Loss', self.loss)
        self.tensorboard_writer = tf.summary.FileWriter(
            logdir=self.args.tensorboard_dir, graph=self.sess.graph)

        model_vars = tf.trainable_variables()

        total_size = 0
        total_bytes = 0
        model_msg = ""
        for var in model_vars:
            # if var.num_elements() is None or [] assume size 0.
            var_size = var.get_shape().num_elements() or 0
            var_bytes = var_size * var.dtype.size
            total_size += var_size
            total_bytes += var_bytes
            model_msg += ' '.join([
                var.name,
                tensor_description(var),
                '[%d, bytes: %d]' % (var_size, var_bytes)
            ])
            model_msg += '\n'
        model_msg += 'Total size of variables: %d \n' % total_size
        model_msg += 'Total bytes of variables: %d \n' % total_bytes
        logger.info(model_msg)
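
The Deep-IRT variant above replaces the plain summary-to-logit projection with an item-response-theory style head: z_t = 3.0 * student_ability - question_difficulty, mapped to a probability either with a sigmoid or, when use_ogive_model is set, with the standard normal CDF (the normal ogive), and then converted back to a clipped logit for the cross-entropy. A tiny numeric sketch with invented values:

import math

student_ability = 0.4          # scalar output of StudentAbilityOutputLayer for one step
question_difficulty = 0.1      # scalar output of QuestionDifficultyOutputLayer

z = 3.0 * student_ability - question_difficulty             # 1.1

p_sigmoid = 1.0 / (1.0 + math.exp(-z))                       # ~0.750 (use_ogive_model=False)
p_ogive = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))         # ~0.864 (standard normal CDF)

# Convert the probability back to a logit, with the same clipping as above
eps = 1e-6
p = min(max(p_ogive, eps), 1.0 - eps)
logit = math.log(p / (1.0 - p))
print(p_sigmoid, p_ogive, logit)
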
예제 #11
0
File: model.py Project: renhongkai/DKVMN
    def sym_gen(self):
        ### TODO input variable 'q_data'
        q_data = mx.sym.Variable('q_data', shape=(self.seqlen, self.batch_size)) # (seqlen, batch_size)
        ### TODO input variable 'qa_data'
        qa_data = mx.sym.Variable('qa_data', shape=(self.seqlen, self.batch_size))  # (seqlen, batch_size)
        ### TODO input variable 'target'
        target = mx.sym.Variable('target', shape=(self.seqlen, self.batch_size)) #(seqlen, batch_size)

        ### Initialize Memory
        init_memory_key = mx.sym.Variable('init_memory_key_weight')
        init_memory_value = mx.sym.Variable('init_memory_value',
                                            shape=(self.memory_size, self.memory_value_state_dim),
                                            init=mx.init.Normal(0.1)) # (self.memory_size, self.memory_value_state_dim)
        init_memory_value = mx.sym.broadcast_to(mx.sym.expand_dims(init_memory_value, axis=0),
                                                shape=(self.batch_size, self.memory_size, self.memory_value_state_dim))

        mem = DKVMN(memory_size=self.memory_size,
                   memory_key_state_dim=self.memory_key_state_dim,
                   memory_value_state_dim=self.memory_value_state_dim,
                   init_memory_key=init_memory_key,
                   init_memory_value=init_memory_value,
                   name="DKVMN")


        ### embedding
        q_data = mx.sym.BlockGrad(q_data)
        q_embed_data = mx.sym.Embedding(data=q_data, input_dim=self.n_question+1,
                                        output_dim=self.q_embed_dim, name='q_embed')
        slice_q_embed_data = mx.sym.SliceChannel(q_embed_data, num_outputs=self.seqlen, axis=0, squeeze_axis=True)

        qa_data = mx.sym.BlockGrad(qa_data)
        qa_embed_data = mx.sym.Embedding(data=qa_data, input_dim=self.n_question*2+1,
                                         output_dim=self.qa_embed_dim, name='qa_embed')
        slice_qa_embed_data = mx.sym.SliceChannel(qa_embed_data, num_outputs=self.seqlen, axis=0, squeeze_axis=True)

        value_read_content_l = []
        input_embed_l = []
        for i in range(self.seqlen):
            ## Attention
            q = slice_q_embed_data[i]
            correlation_weight = mem.attention(q)

            ## Read Process
            read_content = mem.read(correlation_weight) #Shape (batch_size, memory_state_dim)
            ### save intermediate data
            value_read_content_l.append(read_content)
            input_embed_l.append(q)

            ## Write Process
            qa = slice_qa_embed_data[i]
            new_memory_value = mem.write(correlation_weight, qa)

        all_read_value_content = mx.sym.Concat(*value_read_content_l, num_args=self.seqlen, dim=0)

        input_embed_content = mx.sym.Concat(*input_embed_l, num_args=self.seqlen, dim=0)
        input_embed_content = mx.sym.FullyConnected(data=input_embed_content, num_hidden=50, name="input_embed_content")
        input_embed_content = mx.sym.Activation(data=input_embed_content, act_type='tanh', name="input_embed_content_tanh")

        read_content_embed = mx.sym.FullyConnected(data=mx.sym.Concat(all_read_value_content, input_embed_content, num_args=2, dim=1),
                                                   num_hidden=self.final_fc_dim, name="read_content_embed")
        read_content_embed = mx.sym.Activation(data=read_content_embed, act_type='tanh', name="read_content_embed_tanh")

        pred = mx.sym.FullyConnected(data=read_content_embed, num_hidden=1, name="final_fc")

        pred_prob = logistic_regression_mask_output(data=mx.sym.Reshape(pred, shape=(-1, )),
                                                    label=mx.sym.Reshape(data=target, shape=(-1,)),
                                                    ignore_label=-1., name='final_pred')
        return mx.sym.Group([pred_prob])