class MODEL(nn.Module):
    def __init__(self, n_question, batch_size, q_embed_dim, qa_embed_dim, memory_size,
                 memory_key_state_dim, memory_value_state_dim, final_fc_dim, student_num=None):
        super(MODEL, self).__init__()
        self.n_question = n_question
        self.batch_size = batch_size
        self.q_embed_dim = q_embed_dim
        self.qa_embed_dim = qa_embed_dim
        self.memory_size = memory_size
        self.memory_key_state_dim = memory_key_state_dim
        self.memory_value_state_dim = memory_value_state_dim
        self.final_fc_dim = final_fc_dim
        self.student_num = student_num

        self.input_embed_linear = nn.Linear(self.q_embed_dim, self.q_embed_dim, bias=True)
        self.read_embed_linear = nn.Linear(self.memory_value_state_dim + self.q_embed_dim,
                                           self.final_fc_dim, bias=True)
        self.predict_linear = nn.Linear(self.final_fc_dim, 1, bias=True)

        self.init_memory_key = nn.Parameter(torch.randn(self.memory_size, self.memory_key_state_dim))
        nn.init.kaiming_normal_(self.init_memory_key)
        self.init_memory_value = nn.Parameter(torch.randn(self.memory_size, self.memory_value_state_dim))
        nn.init.kaiming_normal_(self.init_memory_value)

        self.mem = DKVMN(memory_size=self.memory_size,
                         memory_key_state_dim=self.memory_key_state_dim,
                         memory_value_state_dim=self.memory_value_state_dim,
                         init_memory_key=self.init_memory_key)

        memory_value = nn.Parameter(
            torch.cat([self.init_memory_value.unsqueeze(0) for _ in range(batch_size)], 0).data)
        self.mem.init_value_memory(memory_value)

        # Question IDs start from 1.
        # nn.Embedding takes a tensor of indices and returns the corresponding embeddings.
        self.q_embed = nn.Embedding(self.n_question + 1, self.q_embed_dim, padding_idx=0)
        self.qa_embed = nn.Embedding(2 * self.n_question + 1, self.qa_embed_dim, padding_idx=0)

    def init_params(self):
        nn.init.kaiming_normal_(self.predict_linear.weight)
        nn.init.kaiming_normal_(self.read_embed_linear.weight)
        nn.init.constant_(self.read_embed_linear.bias, 0)
        nn.init.constant_(self.predict_linear.bias, 0)
        # nn.init.constant(self.input_embed_linear.bias, 0)
        # nn.init.normal(self.input_embed_linear.weight, std=0.02)

    def init_embeddings(self):
        nn.init.kaiming_normal_(self.q_embed.weight)
        nn.init.kaiming_normal_(self.qa_embed.weight)

    def forward(self, q_data, qa_data, target, student_id=None):
        batch_size = q_data.shape[0]  # 32
        seqlen = q_data.shape[1]      # 200

        ## q_t and (q, a) embeddings
        q_embed_data = self.q_embed(q_data)
        qa_embed_data = self.qa_embed(qa_data)

        ## copy the initial value memory batch_size times for the DKVMN
        memory_value = nn.Parameter(
            torch.cat([self.init_memory_value.unsqueeze(0) for _ in range(batch_size)], 0).data)
        self.mem.init_value_memory(memory_value)

        ## slice the data seqlen times along axis 1
        # torch.chunk(tensor, chunk_num, dim)
        slice_q_data = torch.chunk(q_data, seqlen, 1)
        slice_q_embed_data = torch.chunk(q_embed_data, seqlen, 1)
        slice_qa_embed_data = torch.chunk(qa_embed_data, seqlen, 1)

        value_read_content_l = []
        input_embed_l = []
        for i in range(seqlen):
            ## Attention
            q = slice_q_embed_data[i].squeeze(1)
            correlation_weight = self.mem.attention(q)
            ## Read process
            read_content = self.mem.read(correlation_weight)
            ## save intermediate data
            value_read_content_l.append(read_content)
            input_embed_l.append(q)
            ## Write process
            qa = slice_qa_embed_data[i].squeeze(1)
            new_memory_value = self.mem.write(correlation_weight, qa)

        # Projection
        all_read_value_content = torch.cat(
            [value_read_content_l[i].unsqueeze(1) for i in range(seqlen)], 1)
        input_embed_content = torch.cat(
            [input_embed_l[i].unsqueeze(1) for i in range(seqlen)], 1)

        ## project r_t
        input_embed_content = input_embed_content.view(batch_size * seqlen, -1)
        input_embed_content = torch.tanh(self.input_embed_linear(input_embed_content))
        input_embed_content = input_embed_content.view(batch_size, seqlen, -1)

        ## concatenate read content and input embedding
        predict_input = torch.cat([all_read_value_content, input_embed_content], 2)
        read_content_embed = torch.tanh(
            self.read_embed_linear(predict_input.view(batch_size * seqlen, -1)))
        pred = self.predict_linear(read_content_embed)

        # predicts = torch.cat([predict_logs[i] for i in range(seqlen)], 1)
        target_1d = target  # [batch_size * seq_len, 1]
        # mask = target_1d.ge(0)  # [batch_size * seq_len, 1]
        mask = q_data.gt(0).view(-1, 1)
        # pred_1d = predicts.view(-1, 1)  # [batch_size * seq_len, 1]
        pred_1d = pred.view(-1, 1)  # [batch_size * seq_len, 1]

        filtered_pred = torch.masked_select(pred_1d, mask)
        filtered_target = torch.masked_select(target_1d, mask)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(filtered_pred, filtered_target)

        return loss, torch.sigmoid(filtered_pred), filtered_target
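# A hypothetical smoke test for the model above (not part of the original code).
# It assumes the DKVMN memory module from this repository is importable and uses
# made-up hyperparameters; shapes follow the comments in forward() (batch 32, seqlen 200).
import torch

n_question = 100
model = MODEL(n_question=n_question, batch_size=32, q_embed_dim=50, qa_embed_dim=200,
              memory_size=20, memory_key_state_dim=50, memory_value_state_dim=200,
              final_fc_dim=50)
model.init_params()
model.init_embeddings()

q = torch.randint(1, n_question + 1, (32, 200))   # question IDs start from 1; 0 means padding
a = torch.randint(0, 2, (32, 200))                # simulated correctness
qa = q + a * n_question                           # combined (question, answer) index
target = a.float().view(-1, 1)                    # positions with q == 0 would be masked out
loss, pred, true = model(q, qa, target)
loss.backward()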
class MODEL(nn.Module):
    def __init__(self, n_question, batch_size, q_embed_dim, qa_embed_dim, memory_size,
                 memory_key_state_dim, memory_value_state_dim, final_fc_dim, first_k, gpu,
                 student_num=None):
        super(MODEL, self).__init__()
        self.n_question = n_question
        self.batch_size = batch_size
        self.q_embed_dim = q_embed_dim
        self.qa_embed_dim = qa_embed_dim
        self.memory_size = memory_size
        self.memory_key_state_dim = memory_key_state_dim
        self.memory_value_state_dim = memory_value_state_dim
        self.final_fc_dim = final_fc_dim
        self.student_num = student_num
        self.first_k = first_k

        self.read_embed_linear = nn.Linear(self.memory_value_state_dim + self.q_embed_dim,
                                           self.final_fc_dim, bias=True)
        # self.predict_linear = nn.Linear(self.memory_value_state_dim + self.q_embed_dim, 1, bias=True)

        self.init_memory_key = nn.Parameter(torch.randn(self.memory_size, self.memory_key_state_dim))
        nn.init.kaiming_normal_(self.init_memory_key)
        self.init_memory_value = nn.Parameter(torch.randn(self.memory_size, self.memory_value_state_dim))
        nn.init.kaiming_normal_(self.init_memory_value)

        # modify: hop LSTM
        self.hop_lstm = nn.LSTM(input_size=self.memory_value_state_dim + self.q_embed_dim,
                                hidden_size=64, num_layers=1, batch_first=True)
        # hidden_size = 64
        self.predict_linear = nn.Linear(64, 1, bias=True)

        self.mem = DKVMN(memory_size=self.memory_size,
                         memory_key_state_dim=self.memory_key_state_dim,
                         memory_value_state_dim=self.memory_value_state_dim,
                         init_memory_key=self.init_memory_key)

        memory_value = nn.Parameter(
            torch.cat([self.init_memory_value.unsqueeze(0) for _ in range(batch_size)], 0).data)
        self.mem.init_value_memory(memory_value)

        # Question IDs start from 1.
        # nn.Embedding takes a tensor of indices and returns the corresponding embeddings.
        self.q_embed = nn.Embedding(self.n_question + 1, self.q_embed_dim, padding_idx=0)
        self.a_embed = nn.Linear(2 * self.n_question + 1, self.qa_embed_dim, bias=True)
        # self.a_embed = nn.Linear(self.final_fc_dim + 1, self.qa_embed_dim, bias=True)
        # self.correlation_weight_list = []

        if gpu >= 0:
            self.device = torch.device('cuda', gpu)
        else:
            self.device = torch.device('cpu')
        print("num_layers=1, hidden_size=64, a=0.075, b=0.088, c=1.00, triangular, onehot")

    def init_params(self):
        nn.init.kaiming_normal_(self.predict_linear.weight)
        nn.init.kaiming_normal_(self.read_embed_linear.weight)
        nn.init.constant_(self.read_embed_linear.bias, 0)
        nn.init.constant_(self.predict_linear.bias, 0)

    def init_embeddings(self):
        nn.init.kaiming_normal_(self.q_embed.weight)

    # Method 2: set the top-k entries of the correlation weight vector to 1
    def identity_layer(self, correlation_weight, seqlen, k=1):
        batch_identity_indices = []
        correlation_weight = correlation_weight.view(self.batch_size * seqlen, -1)
        # For each sequence step in the batch, set the top-k entries to 1 and the rest to 0.
        _, indices = correlation_weight.topk(k, dim=1, largest=True)
        identity_matrix = torch.zeros([self.batch_size * seqlen, self.memory_size])
        for i, m in enumerate(indices):
            identity_matrix[i, m] = 1
        identity_vector_batch = identity_matrix.view(self.batch_size * seqlen, -1)
        unique_iv = torch.unique(identity_vector_batch, sorted=False, dim=0)
        self.unique_len = unique_iv.shape[0]
        # A^2
        iv_square_norm = torch.sum(torch.pow(identity_vector_batch, 2), dim=1, keepdim=True)
        iv_square_norm = iv_square_norm.repeat((1, self.unique_len))
        # B^2.T
        unique_iv_square_norm = torch.sum(torch.pow(unique_iv, 2), dim=1, keepdim=True)
        unique_iv_square_norm = unique_iv_square_norm.repeat((1, self.batch_size * seqlen)).transpose(1, 0)
        # A * B.T
        iv_matrix_product = identity_vector_batch.mm(unique_iv.transpose(1, 0))
        # A^2 + B^2 - 2 * A * B.T
        iv_distances = iv_square_norm + unique_iv_square_norm - 2 * iv_matrix_product
        indices = (iv_distances == 0).nonzero()
        batch_identity_indices = indices[:, -1]
        return batch_identity_indices

    # Method 1: compute the identity vector with a triangular membership function
    def triangular_layer(self, correlation_weight, seqlen, a=0.075, b=0.088, c=1.00):
        batch_identity_indices = []
        # w' = max(min((w - a) / (b - a), (c - w) / (c - b)), 0)
        correlation_weight = correlation_weight.view(self.batch_size * seqlen, -1)
        correlation_weight = torch.cat(
            [correlation_weight[i] for i in range(correlation_weight.shape[0])], 0).unsqueeze(0)
        correlation_weight = torch.cat([(correlation_weight - a) / (b - a),
                                        (c - correlation_weight) / (c - b)], 0)
        correlation_weight, _ = torch.min(correlation_weight, 0)
        w0 = torch.zeros(correlation_weight.shape[0]).to(self.device)
        correlation_weight = torch.cat([correlation_weight.unsqueeze(0), w0.unsqueeze(0)], 0)
        correlation_weight, _ = torch.max(correlation_weight, 0)

        identity_vector_batch = torch.zeros(correlation_weight.shape[0]).to(self.device)
        # Values >= 0.6 map to 2, values in [0.1, 0.6) map to 1, values below 0.1 map to 0.
        # mask = correlation_weight.lt(0.1)
        identity_vector_batch = identity_vector_batch.masked_fill(correlation_weight.lt(0.1), 0)
        # mask = correlation_weight.ge(0.1)
        identity_vector_batch = identity_vector_batch.masked_fill(correlation_weight.ge(0.1), 1)
        # mask = correlation_weight.ge(0.6)
        _identity_vector_batch = identity_vector_batch.masked_fill(correlation_weight.ge(0.6), 2)
        # identity_vector_batch = torch.chunk(identity_vector_batch.view(self.batch_size, -1), self.batch_size, 0)

        # Input: _identity_vector_batch
        # Output: indices
        identity_vector_batch = _identity_vector_batch.view(self.batch_size * seqlen, -1)
        unique_iv = torch.unique(identity_vector_batch, sorted=False, dim=0)
        self.unique_len = unique_iv.shape[0]
        # A^2
        iv_square_norm = torch.sum(torch.pow(identity_vector_batch, 2), dim=1, keepdim=True)
        iv_square_norm = iv_square_norm.repeat((1, self.unique_len))
        # B^2.T
        unique_iv_square_norm = torch.sum(torch.pow(unique_iv, 2), dim=1, keepdim=True)
        unique_iv_square_norm = unique_iv_square_norm.repeat((1, self.batch_size * seqlen)).transpose(1, 0)
        # A * B.T
        iv_matrix_product = identity_vector_batch.mm(unique_iv.transpose(1, 0))
        # A^2 + B^2 - 2 * A * B.T
        iv_distances = iv_square_norm + unique_iv_square_norm - 2 * iv_matrix_product
        indices = (iv_distances == 0).nonzero()
        batch_identity_indices = indices[:, -1]
        return batch_identity_indices

    def forward(self, q_data, qa_data, a_data, target, student_id=None):
        batch_size = q_data.shape[0]  # 32
        seqlen = q_data.shape[1]      # 200

        ## q_t embedding
        q_embed_data = self.q_embed(q_data)

        # modify: build the one-hot y_t vector for each exercise
        a_onehot_array = []
        for i in range(a_data.shape[0]):
            for j in range(a_data.shape[1]):
                a_onehot = np.zeros(self.n_question + 1)
                index = a_data[i][j]
                if index > 0:
                    a_onehot[index] = 1
                a_onehot_array.append(a_onehot)
        a_onehot_content = torch.cat(
            [torch.Tensor(a_onehot_array[i]).unsqueeze(0) for i in range(len(a_onehot_array))], 0)
        a_onehot_content = a_onehot_content.view(batch_size, seqlen, -1).to(self.device)

        ## copy the initial value memory batch_size times for the DKVMN
        memory_value = nn.Parameter(
            torch.cat([self.init_memory_value.unsqueeze(0) for _ in range(batch_size)], 0).data)
        self.mem.init_value_memory(memory_value)

        ## slice the data seqlen times along axis 1
        slice_q_data = torch.chunk(q_data, seqlen, 1)
        slice_q_embed_data = torch.chunk(q_embed_data, seqlen, 1)
        # modify
        slice_a_onehot_content = torch.chunk(a_onehot_content, seqlen, 1)
        # slice_a = torch.chunk(a_data, seqlen, 1)

        value_read_content_l = []
        input_embed_l = []
        correlation_weight_list = []
        # modify
        f_t = []
        # (n_layers, batch_size, hidden_dim)
        init_h = torch.randn(1, self.batch_size, 64).to(self.device)
        init_c = torch.randn(1, self.batch_size, 64).to(self.device)

        for i in range(seqlen):
            ## Attention
            q = slice_q_embed_data[i].squeeze(1)
            correlation_weight = self.mem.attention(q)
            ## Read process
            read_content = self.mem.read(correlation_weight)
            # modify
            correlation_weight_list.append(correlation_weight)
            ## save intermediate data
            value_read_content_l.append(read_content)
            input_embed_l.append(q)
            # modify
            batch_predict_input = torch.cat([read_content, q], 1)
            f = self.read_embed_linear(batch_predict_input)
            f_t.append(batch_predict_input)
            # The input written to the value matrix is [y_t, f_t]: the one-hot vector concatenated with f_t.
            onehot = slice_a_onehot_content[i].squeeze(1)
            write_embed = torch.cat([onehot, f], 1)
            # Alternative: write [f_t, y_t], i.e. f_t concatenated directly with the answer (0 or 1).
            # write_embed = torch.cat([f, slice_a[i].float()], 1)
            write_embed = self.a_embed(write_embed)
            new_memory_value = self.mem.write(correlation_weight, write_embed)

        # modify
        correlation_weight_matrix = torch.cat(
            [correlation_weight_list[i].unsqueeze(1) for i in range(seqlen)], 1)
        identity_index_list = self.triangular_layer(correlation_weight_matrix, seqlen)
        # identity_index_list = self.identity_layer(correlation_weight_matrix, seqlen)
        identity_index_list = identity_index_list.view(self.batch_size, seqlen)
        # identity_index_list = identity_index_list[:, self.first_k:]  # do not predict the first k steps
        # identity_index_list = torch.cat([identity_index_list[i].unsqueeze(1) for i in range(seqlen)], 1)
        f_t = torch.cat([f_t[i].unsqueeze(1) for i in range(seqlen)], 1)
        # f_t = f_t[:, self.first_k:]  # do not predict the first k steps
        target_seqlayer = target.view(batch_size, seqlen, -1)
        # target_seqlayer = target_seqlayer[:, self.first_k:]  # do not predict the first k steps

        target_sequence = []
        pred_sequence = []
        for idx in range(self.unique_len):
            # start = time.time()
            hop_lstm_input = []
            hop_lstm_target = []
            max_seq = 1
            zero_count = 0
            for i in range(self.batch_size):
                # Indices, within this sequence, of the exercises whose identity vector
                # matches the one currently being predicted.
                index = list((identity_index_list[i, :] == idx).nonzero())
                max_seq = max(max_seq, len(index))
                if len(index) == 0:
                    hop_lstm_input.append(torch.zeros([1, self.memory_value_state_dim + self.q_embed_dim]))
                    hop_lstm_target.append(torch.full([1, 1], -1))
                    zero_count += 1
                    continue
                else:
                    index = torch.LongTensor(index).to(self.device)
                    hop_lstm_target_slice = torch.index_select(target_seqlayer[i, :, :], 0, index)
                    hop_lstm_input_slice = torch.index_select(f_t[i, :, :], 0, index)
                    hop_lstm_input.append(hop_lstm_input_slice)
                    hop_lstm_target.append(hop_lstm_target_slice)
            if zero_count == 32:
                continue

            # Pad the input and target matrices.
            for i in range(self.batch_size):
                x = torch.zeros([max_seq, self.memory_value_state_dim + self.q_embed_dim])
                x[:len(hop_lstm_input[i]), :] = hop_lstm_input[i]
                hop_lstm_input[i] = x
                y = torch.full([max_seq, 1], -1)
                y[:len(hop_lstm_target[i]), :] = hop_lstm_target[i]
                hop_lstm_target[i] = y

            # Predict with the hop LSTM.
            hop_lstm_input = torch.cat(
                [hop_lstm_input[i].unsqueeze(0) for i in range(self.batch_size)], 0).to(self.device)
            hop_lstm_target = torch.cat(
                [hop_lstm_target[i].unsqueeze(0) for i in range(self.batch_size)], 0)
            hop_lstm_output, _ = self.hop_lstm(hop_lstm_input, (init_h, init_c))
            pred = self.predict_linear(hop_lstm_output)
            pred = pred.view(self.batch_size * max_seq, -1)
            hop_lstm_target = hop_lstm_target.view(self.batch_size * max_seq, -1).to(self.device)
            mask = hop_lstm_target.ge(0)
            hop_lstm_target = torch.masked_select(hop_lstm_target, mask)
            pred = torch.sigmoid(torch.masked_select(pred, mask))
            target_sequence.append(hop_lstm_target)
            pred_sequence.append(pred)

            # During training, backpropagate separately through the LSTM run of each identity vector.
            if self.training is True:
                subsequence_loss = torch.nn.functional.binary_cross_entropy_with_logits(pred, hop_lstm_target)
                subsequence_loss.backward(retain_graph=True)

        # Compute the loss over all exercises in the batch.
        target_sequence = torch.cat([target_sequence[i] for i in range(len(target_sequence))], 0)
        pred_sequence = torch.cat([pred_sequence[i] for i in range(len(pred_sequence))], 0)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(pred_sequence, target_sequence)
        return loss, pred_sequence, target_sequence
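# A minimal standalone sketch of the triangular membership transform used in
# triangular_layer above, with the same default parameters a=0.075, b=0.088, c=1.00.
# The helper name and the example values are illustrative only.
import torch

def triangular_membership(w, a=0.075, b=0.088, c=1.00):
    # w' = max(min((w - a) / (b - a), (c - w) / (c - b)), 0), applied elementwise
    rising = (w - a) / (b - a)
    falling = (c - w) / (c - b)
    return torch.clamp(torch.min(rising, falling), min=0.0)

w = torch.tensor([0.05, 0.08, 0.09, 0.5, 0.95])
m = triangular_membership(w)
# Discretize as in the model: >= 0.6 -> 2, [0.1, 0.6) -> 1, < 0.1 -> 0
levels = torch.zeros_like(m)
levels[m.ge(0.1)] = 1
levels[m.ge(0.6)] = 2
print(m, levels)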
def sym_gen(self):
    ### TODO input variable 'q_data'
    q_data = mx.sym.Variable('q_data', shape=(self.seqlen, self.batch_size))  # (seqlen, batch_size)
    ### TODO input variable 'qa_data'
    qa_data = mx.sym.Variable('qa_data', shape=(self.seqlen, self.batch_size))  # (seqlen, batch_size)
    ### TODO input variable 'target'
    target = mx.sym.Variable('target', shape=(self.seqlen, self.batch_size))  # (seqlen, batch_size)

    ### Initialize memory
    init_memory_key = mx.sym.Variable('init_memory_key_weight')
    init_memory_value = mx.sym.Variable('init_memory_value',
                                        shape=(self.memory_size, self.memory_value_state_dim),
                                        init=mx.init.Normal(0.1))  # (memory_size, memory_value_state_dim)
    init_memory_value = mx.sym.broadcast_to(mx.sym.expand_dims(init_memory_value, axis=0),
                                            shape=(self.batch_size, self.memory_size, self.memory_value_state_dim))

    mem = DKVMN(memory_size=self.memory_size,
                memory_key_state_dim=self.memory_key_state_dim,
                memory_value_state_dim=self.memory_value_state_dim,
                init_memory_key=init_memory_key,
                init_memory_value=init_memory_value,
                name="DKVMN")

    ### embedding
    q_data = mx.sym.BlockGrad(q_data)
    q_embed_data = mx.sym.Embedding(data=q_data, input_dim=self.n_question + 1,
                                    output_dim=self.q_embed_dim, name='q_embed')
    slice_q_embed_data = mx.sym.SliceChannel(q_embed_data, num_outputs=self.seqlen, axis=0, squeeze_axis=True)
    qa_data = mx.sym.BlockGrad(qa_data)
    qa_embed_data = mx.sym.Embedding(data=qa_data, input_dim=self.n_question * 2 + 1,
                                     output_dim=self.qa_embed_dim, name='qa_embed')
    slice_qa_embed_data = mx.sym.SliceChannel(qa_embed_data, num_outputs=self.seqlen, axis=0, squeeze_axis=True)

    value_read_content_l = []
    input_embed_l = []
    readDict = {object: []}
    for i in range(self.seqlen):
        ## Attention
        q = mx.sym.L2Normalization(slice_q_embed_data[i], mode='instance')
        correlation_weight = mem.attention(q)
        ## Read process
        read_content = mem.read(correlation_weight)  # Shape (batch_size, memory_state_dim)
        ### save intermediate data [OLD]
        value_read_content_l.append(read_content)
        input_embed_l.append(q)
        ## Write process
        # qa = slice_qa_embed_data[i]
        # new_memory_value = mem.write(correlation_weight, mx.sym.Concat(qa, read_content))
        qa = mx.sym.concat(mx.sym.L2Normalization(read_content, mode='instance'),
                           mx.sym.L2Normalization(slice_qa_embed_data[i], mode='instance'))
        # qa = mx.sym.L2Normalization(slice_qa_embed_data[i], mode='instance')
        new_memory_value = mem.write(correlation_weight, qa)

    # ============== [ Cluster related read_contents based on fuzzy representation ] ==============
    for i in range(0, len(value_read_content_l)):
        current_fuzz_rep = mx.symbol.Custom(data=value_read_content_l[i], name='fuzzkey', op_type='fuzzify')
        related = [value_read_content_l[i]]
        for j in range(0, len(value_read_content_l)):
            if i != j:
                tmp_fuzz = mx.symbol.Custom(data=value_read_content_l[j], name='fuzzkey', op_type='fuzzify')
                if current_fuzz_rep.tojson() == tmp_fuzz.tojson():
                    related.append(value_read_content_l[j])
        value_read_content_l[i] = mx.sym.Reshape(
            data=mx.sym.RNN(data=related, state_size=self.memory_value_state_dim,
                            num_layers=2, mode='lstm', p=0.2),  # Shape (batch_size, 1, memory_state_dim)
            shape=(-1, self.memory_value_state_dim))
    # ==============================================================================================

    all_read_value_content = mx.sym.Concat(*value_read_content_l, num_args=self.seqlen, dim=0)
    input_embed_content = mx.sym.Concat(*input_embed_l, num_args=self.seqlen, dim=0)
    input_embed_content = mx.sym.FullyConnected(data=mx.sym.L2Normalization(input_embed_content, mode='instance'),
                                                num_hidden=64, name="input_embed_content")
    input_embed_content = mx.sym.Activation(data=mx.sym.L2Normalization(input_embed_content, mode='instance'),
                                            act_type='tanh', name="input_embed_content_tanh")

    read_content_embed = mx.sym.FullyConnected(
        data=mx.sym.Concat(mx.sym.L2Normalization(all_read_value_content, mode='instance'),
                           mx.sym.L2Normalization(input_embed_content, mode='instance'),
                           num_args=2, dim=1),
        num_hidden=self.final_fc_dim, name="read_content_embed")
    read_content_embed = mx.sym.Activation(data=mx.sym.L2Normalization(read_content_embed, mode='instance'),
                                           act_type='tanh', name="read_content_embed_tanh")

    # ==================================== [ Updated for F value ] ====================================
    # for i in range(self.seqlen):
    #     ## Write process
    #     qa = mx.symbol.batch_dot(slice_qa_embed_data[i], read_content_embed)
    #     # qa = mx.sym.Concat(slice_qa_embed_data[i], read_content_embed)
    #     # qa = read_content_embed
    #     new_memory_value = mem.write(correlation_weight, qa)
    # ==================================================================================================

    pred = mx.sym.FullyConnected(data=mx.sym.L2Normalization(read_content_embed, mode='instance'),
                                 num_hidden=1, name="final_fc")
    pred_prob = logistic_regression_mask_output(data=mx.sym.Reshape(pred, shape=(-1,)),
                                                label=mx.sym.Reshape(data=target, shape=(-1,)),
                                                ignore_label=-1., name='final_pred')
    return mx.sym.Group([pred_prob])
class Model():
    def __init__(self, args, sess, name='KT'):
        self.args = args
        self.name = name
        self.sess = sess
        self.sess.run(tf.global_variables_initializer())
        self.create_model()
        # if self.load():
        #     print('CKPT Loaded')
        # else:
        #     raise Exception('CKPT need')

    def create_model(self):
        # 'seq_len' means question sequences
        self.q_data = tf.placeholder(tf.int32, [self.args.batch_size, self.args.seq_len], name='q_data')
        self.qa_data = tf.placeholder(tf.int32, [self.args.batch_size, self.args.seq_len], name='qa_data')
        self.target = tf.placeholder(tf.float32, [self.args.batch_size, self.args.seq_len], name='target')
        self.kg = tf.placeholder(tf.int32, [self.args.batch_size, self.args.seq_len, 3], name='knowledge_tag')
        self.kg_hot = tf.placeholder(tf.float32, [self.args.batch_size, self.args.seq_len, 188], name='knowledge_hot')
        self.timebin = tf.placeholder(tf.int32, [self.args.batch_size, self.args.seq_len])
        self.diff = tf.placeholder(tf.int32, [self.args.batch_size, self.args.seq_len])
        self.guan = tf.placeholder(tf.int32, [self.args.batch_size, self.args.seq_len])

        with tf.variable_scope('Memory'):
            init_memory_key = tf.get_variable('key', [self.args.memory_size, self.args.memory_key_state_dim],
                                              initializer=tf.truncated_normal_initializer(stddev=0.1))
            init_memory_value = tf.get_variable('value', [self.args.memory_size, self.args.memory_value_state_dim],
                                                initializer=tf.truncated_normal_initializer(stddev=0.1))
        with tf.variable_scope('time'):
            time_embed_mtx = tf.get_variable('timebin', [12, self.args.memory_value_state_dim],
                                             initializer=tf.truncated_normal_initializer(stddev=0.1))
        with tf.variable_scope('diff'):
            guan_embed_mtx = tf.get_variable('diff', [12, self.args.memory_value_state_dim],
                                             initializer=tf.truncated_normal_initializer(stddev=0.1))
        with tf.variable_scope('gate'):
            diff_embed_mtx = tf.get_variable('gate', [12, self.args.memory_value_state_dim],
                                             initializer=tf.truncated_normal_initializer(stddev=0.1))

        init_memory_value = tf.tile(tf.expand_dims(init_memory_value, 0), tf.stack([self.args.batch_size, 1, 1]))
        print(init_memory_value.get_shape())

        self.memory = DKVMN(self.args.memory_size, self.args.memory_key_state_dim,
                            self.args.memory_value_state_dim, init_memory_key=init_memory_key,
                            init_memory_value=init_memory_value, batch_size=self.args.batch_size, name='DKVMN')

        with tf.variable_scope('Embedding'):
            # A
            q_embed_mtx = tf.get_variable('q_embed', [self.args.n_questions + 1, self.args.memory_key_state_dim],
                                          initializer=tf.truncated_normal_initializer(stddev=0.1))
            # B
            qa_embed_mtx = tf.get_variable('qa_embed',
                                           [2 * self.args.n_questions + 1, self.args.memory_value_state_dim],
                                           initializer=tf.truncated_normal_initializer(stddev=0.1))

        q_embed_data = tf.nn.embedding_lookup(q_embed_mtx, self.q_data)
        slice_q_embed_data = tf.split(q_embed_data, self.args.seq_len, 1)
        qa_embed_data = tf.nn.embedding_lookup(qa_embed_mtx, self.qa_data)
        slice_qa_embed_data = tf.split(qa_embed_data, self.args.seq_len, 1)
        time_embedding = tf.nn.embedding_lookup(time_embed_mtx, self.timebin)
        slice_time_embedding = tf.split(time_embedding, self.args.seq_len, 1)
        guan_embedding = tf.nn.embedding_lookup(diff_embed_mtx, self.diff)
        slice_guan_embedding = tf.split(guan_embedding, self.args.seq_len, 1)
        diff_embedding = tf.nn.embedding_lookup(diff_embed_mtx, self.diff)
        slice_diff_embedding = tf.split(diff_embedding, self.args.seq_len, 1)
        slice_kg = tf.split(self.kg, self.args.seq_len, 1)
        slice_kg_hot = tf.split(self.kg_hot, self.args.seq_len, 1)

        reuse_flag = False
        prediction = list()

        # Logics
        for i in range(self.args.seq_len):
            # To reuse linear vectors
            if i != 0:
                reuse_flag = True
            q = tf.squeeze(slice_q_embed_data[i], 1)
            qa = tf.squeeze(slice_qa_embed_data[i], 1)
            kg = tf.squeeze(slice_kg[i], 1)
            kg_hot = tf.squeeze(slice_kg_hot[i], 1)
            dotime = tf.squeeze(slice_time_embedding[i], 1)
            dodiff = tf.squeeze(slice_diff_embedding[i], 1)
            doguan = tf.squeeze(slice_guan_embedding[i], 1)

            self.correlation_weight = self.memory.attention(q, kg, kg_hot)
            # Read process, [batch size, memory value state dim]
            self.read_content = self.memory.read(self.correlation_weight)

            mastery_level_prior_difficulty = tf.concat([self.read_content, q, doguan], 1)
            # f_t
            summary_vector = tf.tanh(operations.linear(mastery_level_prior_difficulty, self.args.final_fc_dim,
                                                       name='Summary_Vector', reuse=reuse_flag))
            # p_t
            pred_logits = operations.linear(summary_vector, 1, name='Prediction', reuse=reuse_flag)
            prediction.append(pred_logits)

            qa_time = tf.concat([qa, dotime], axis=1)
            self.new_memory_value = self.memory.write(self.correlation_weight, qa_time, reuse=reuse_flag)

        # 'prediction' : seq_len length list of [batch size, 1], make it a [batch size, seq_len] tensor
        # tf.stack converts it to [batch size, seq_len, 1]
        self.pred_logits = tf.reshape(tf.stack(prediction, axis=1), [self.args.batch_size, self.args.seq_len])

        # Define loss : standard cross entropy loss, need to ignore '-1' label examples
        # Make target/label a 1-d array
        target_1d = tf.reshape(self.target, [-1])
        pred_logits_1d = tf.reshape(self.pred_logits, [-1])
        index = tf.where(tf.not_equal(target_1d, tf.constant(-1., dtype=tf.float32)))
        # tf.gather(params, indices) : Gather slices from params according to indices
        filtered_target = tf.gather(target_1d, index)
        filtered_logits = tf.gather(pred_logits_1d, index)
        self.loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=filtered_logits, labels=filtered_target))
        self.pred = tf.sigmoid(self.pred_logits)

        # Optimizer : SGD + momentum with learning rate decay
        self.global_step = tf.Variable(0, trainable=False)
        self.lr = tf.placeholder(tf.float32, [], name='learning_rate')
        optimizer = tf.train.MomentumOptimizer(self.lr, self.args.momentum)
        grads, vrbs = zip(*optimizer.compute_gradients(self.loss))
        grad, _ = tf.clip_by_global_norm(grads, self.args.maxgradnorm)
        self.train_op = optimizer.apply_gradients(zip(grad, vrbs), global_step=self.global_step)

        self.tr_vrbs = tf.trainable_variables()
        self.params = {}
        for i in self.tr_vrbs:
            print(i.name)
            self.params[i.name] = tf.get_default_graph().get_tensor_by_name(i.name)
        self.saver = tf.train.Saver()

    def getParam(self):
        """
        Get the parameters of the DKVMN-CA model.
        """
        params = self.sess.run([self.params])
        with open('good_future.pkl', 'wb') as f:
            pickle.dump(params, f)

    def train(self, train_q_data, train_qa_data, valid_q_data, valid_qa_data, train_kg_data, valid_kg_data,
              train_kgnum_data, valid_kgnum_data, traintime, validtime, trainguan, validguan, traindiff, validdiff):
        """
        :param train_q_data: exercise IDs
        :param train_qa_data: exercise IDs combined with answer results
        :param train_kg_data: one-hot form of knowledge concepts
        :param train_kgnum_data: knowledge concept tags
        :param traintime: completion time
        :param trainguan: the gate of the exercise
        :param traindiff: the difficulty of the exercise
        """
        shuffle_index = np.random.permutation(train_q_data.shape[0])
        q_data_shuffled = train_q_data[shuffle_index]
        qa_data_shuffled = train_qa_data[shuffle_index]
        kg_shuffled = train_kgnum_data[shuffle_index]
        kghot_shuffled = train_kg_data[shuffle_index]
        time_shuffled = traintime[shuffle_index]
        guan_shuffled = trainguan[shuffle_index]
        diff_shuffled = traindiff[shuffle_index]

        training_step = train_q_data.shape[0] // self.args.batch_size
        self.sess.run(tf.global_variables_initializer())

        self.train_count = 0
        if self.args.init_from:
            if self.load():
                print('Checkpoint_loaded')
            else:
                print('No checkpoint')
        else:
            if os.path.exists(os.path.join(self.args.checkpoint_dir, self.model_dir)):
                try:
                    shutil.rmtree(os.path.join(self.args.checkpoint_dir, self.model_dir))
                    shutil.rmtree(os.path.join(self.args.log_dir, self.model_dir + '.csv'))
                except (FileNotFoundError, IOError) as e:
                    print('[Delete Error] %s - %s' % (e.filename, e.strerror))

        best_valid_auc = 0

        # Training
        for epoch in range(0, self.args.num_epochs):
            if self.args.show:
                bar.next()

            pred_list = list()
            target_list = list()
            epoch_loss = 0
            for steps in range(training_step):
                # [batch size, seq_len]
                q_batch_seq = q_data_shuffled[steps * self.args.batch_size:(steps + 1) * self.args.batch_size, :]
                qa_batch_seq = qa_data_shuffled[steps * self.args.batch_size:(steps + 1) * self.args.batch_size, :]
                kg_batch_seq = kg_shuffled[steps * self.args.batch_size:(steps + 1) * self.args.batch_size, :]
                kghot_batch_seq = kghot_shuffled[steps * self.args.batch_size:(steps + 1) * self.args.batch_size, :]
                time_batch_seq = time_shuffled[steps * self.args.batch_size:(steps + 1) * self.args.batch_size, :]
                guan_batch_seq = guan_shuffled[steps * self.args.batch_size:(steps + 1) * self.args.batch_size, :]
                diff_batch_seq = diff_shuffled[steps * self.args.batch_size:(steps + 1) * self.args.batch_size, :]

                # qa : exercise index + answer(0 or 1) * exercise_number
                # right : 1, wrong : 0, padding : -1
                target = qa_batch_seq[:, :]
                # Make integer type to calculate target
                target = target.astype(np.int)
                target_batch = (target - 1) // self.args.n_questions
                target_batch = target_batch.astype(np.float)

                feed_dict = {
                    self.kg: kg_batch_seq,
                    self.q_data: q_batch_seq,
                    self.qa_data: qa_batch_seq,
                    self.target: target_batch,
                    self.kg_hot: kghot_batch_seq,
                    self.lr: self.args.initial_lr,
                    self.timebin: time_batch_seq,
                    self.diff: diff_batch_seq,
                    self.guan: guan_batch_seq
                }
                loss_, pred_, _, = self.sess.run([self.loss, self.pred, self.train_op], feed_dict=feed_dict)

                right_target = np.asarray(target_batch).reshape(-1, 1)
                right_pred = np.asarray(pred_).reshape(-1, 1)
                right_index = np.flatnonzero(right_target != -1.).tolist()
                pred_list.append(right_pred[right_index])
                target_list.append(right_target[right_index])
                epoch_loss += loss_

            if self.args.show:
                bar.finish()

            all_pred = np.concatenate(pred_list, axis=0)
            all_target = np.concatenate(target_list, axis=0)

            # Compute metrics
            self.auc = metrics.roc_auc_score(all_target, all_pred)
            # Extract elements with boolean index
            # Make '1' for elements higher than 0.5
            # Make '0' for elements lower than 0.5
            all_pred[all_pred > 0.5] = 1
            all_pred[all_pred <= 0.5] = 0
            self.accuracy = metrics.accuracy_score(all_target, all_pred)

            epoch_loss = epoch_loss / training_step
            print('Epoch %d/%d, loss : %3.5f, auc : %3.5f, accuracy : %3.5f' %
                  (epoch + 1, self.args.num_epochs, epoch_loss, self.auc, self.accuracy))
            self.write_log(epoch=epoch + 1, auc=self.auc, accuracy=self.accuracy, loss=epoch_loss, name='training_')

            valid_steps = valid_q_data.shape[0] // self.args.batch_size
            valid_pred_list = list()
            valid_target_list = list()
            for s in range(valid_steps):
                # Validation
                valid_q = valid_q_data[s * self.args.batch_size:(s + 1) * self.args.batch_size, :]
                valid_qa = valid_qa_data[s * self.args.batch_size:(s + 1) * self.args.batch_size, :]
                valid_kg = valid_kgnum_data[s * self.args.batch_size:(s + 1) * self.args.batch_size, :]
                valid_hot_kg = valid_kg_data[s * self.args.batch_size:(s + 1) * self.args.batch_size, :]
                valid_time = validtime[s * self.args.batch_size:(s + 1) * self.args.batch_size, :]
                valid_guan = validguan[s * self.args.batch_size:(s + 1) * self.args.batch_size, :]
                valid_diff = validdiff[s * self.args.batch_size:(s + 1) * self.args.batch_size, :]

                # right : 1, wrong : 0, padding : -1
                valid_target = (valid_qa - 1) // self.args.n_questions
                valid_feed_dict = {
                    self.kg: valid_kg,
                    self.q_data: valid_q,
                    self.qa_data: valid_qa,
                    self.kg_hot: valid_hot_kg,
                    self.target: valid_target,
                    self.timebin: valid_time,
                    self.guan: valid_guan,
                    self.diff: valid_diff
                }
                valid_loss, valid_pred = self.sess.run([self.loss, self.pred], feed_dict=valid_feed_dict)

                # Same as for the training set
                valid_right_target = np.asarray(valid_target).reshape(-1, )
                valid_right_pred = np.asarray(valid_pred).reshape(-1, )
                valid_right_index = np.flatnonzero(valid_right_target != -1).tolist()
                valid_target_list.append(valid_right_target[valid_right_index])
                valid_pred_list.append(valid_right_pred[valid_right_index])

            all_valid_pred = np.concatenate(valid_pred_list, axis=0)
            all_valid_target = np.concatenate(valid_target_list, axis=0)
            valid_auc = metrics.roc_auc_score(all_valid_target, all_valid_pred)

            # For validation accuracy
            stop = 0
            all_valid_pred[all_valid_pred > 0.5] = 1
            all_valid_pred[all_valid_pred <= 0.5] = 0
            valid_accuracy = metrics.accuracy_score(all_valid_target, all_valid_pred)
            print('Epoch %d/%d, valid auc : %3.5f, valid accuracy : %3.5f' %
                  (epoch + 1, self.args.num_epochs, valid_auc, valid_accuracy))
            # Valid log
            self.write_log(epoch=epoch + 1, auc=valid_auc, accuracy=valid_accuracy, loss=valid_loss, name='valid_')

            if valid_auc > best_valid_auc:
                print('%3.4f to %3.4f' % (best_valid_auc, valid_auc))
                best_valid_auc = valid_auc
                best_acc = valid_accuracy
                best_epoch = epoch + 1
                # self.save(best_epoch)
            else:
                if epoch - best_epoch >= 2:
                    with open(self.args.dataset + 'concat', 'a') as f:
                        f.write('auc:' + str(best_valid_auc) + ',acc:' + str(best_acc) + '\n')
                    self.args.count += 1
                    break

    @property
    def model_dir(self):
        return '{}_{}batch_{}epochs'.format(self.args.dataset + str(self.args.count), self.args.batch_size,
                                            self.args.num_epochs)

    def load(self):
        checkpoint_dir = os.path.join(self.args.checkpoint_dir, self.model_dir)
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.train_count = int(ckpt_name.split('-')[-1])
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            return True
        else:
            return False

    def save(self, global_step):
        model_name = 'DKVMN'
        checkpoint_dir = os.path.join(self.args.checkpoint_dir, self.model_dir)
        if not os.path.exists(checkpoint_dir):
            os.mkdir(checkpoint_dir)
        self.saver.save(self.sess, os.path.join(checkpoint_dir, model_name), global_step=global_step)
        print('Save checkpoint at %d' % (global_step + 1))

    # Log file
    def write_log(self, auc, accuracy, loss, epoch, name='training_'):
        log_path = os.path.join(self.args.log_dir, name + self.model_dir + '.csv')
        if not os.path.exists(log_path):
            self.log_file = open(log_path, 'w')
            self.log_file.write('Epoch\tAuc\tAccuracy\tloss\n')
        else:
            self.log_file = open(log_path, 'a')
        self.log_file.write(str(epoch) + '\t' + str(auc) + '\t' + str(accuracy) + '\t' + str(loss) + '\n')
        self.log_file.flush()
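# A small worked example of the target encoding used above (and in the other training
# loops in this repo): qa folds the question ID and the answer into one index, and
# (qa - 1) // n_questions recovers right=1, wrong=0, padding=-1. The values here are illustrative.
import numpy as np

n_questions = 100
q = np.array([[7, 42, 0]])                           # 0 marks padding
a = np.array([[1, 0, 0]])                            # correct / wrong / (ignored at padding)
qa = np.where(q > 0, q + a * n_questions, 0)         # [[107, 42, 0]]
target = (qa.astype(np.int64) - 1) // n_questions    # [[1, 0, -1]]
print(qa, target)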
def sym_gen(self):
    ### TODO input variable 'q_data'
    q_data = mx.sym.Variable('q_data', shape=(self.seqlen, self.batch_size))  # (seqlen, batch_size)
    ### TODO input variable 'qa_data'
    qa_data = mx.sym.Variable('qa_data', shape=(self.seqlen, self.batch_size))  # (seqlen, batch_size)
    ### TODO input variable 'target'
    target = mx.sym.Variable('target', shape=(self.seqlen, self.batch_size))  # (seqlen, batch_size)

    ### Initialize memory
    init_memory_key = mx.sym.Variable('init_memory_key_weight')
    init_memory_value = mx.sym.Variable('init_memory_value',
                                        shape=(self.memory_size, self.memory_value_state_dim),
                                        init=mx.init.Normal(0.1))  # (memory_size, memory_value_state_dim)
    init_memory_value = mx.sym.broadcast_to(mx.sym.expand_dims(init_memory_value, axis=0),
                                            shape=(self.batch_size, self.memory_size, self.memory_value_state_dim))

    mem = DKVMN(memory_size=self.memory_size,
                memory_key_state_dim=self.memory_key_state_dim,
                memory_value_state_dim=self.memory_value_state_dim,
                init_memory_key=init_memory_key,
                init_memory_value=init_memory_value,
                name="DKVMN")

    ### embedding
    q_data = mx.sym.BlockGrad(q_data)
    q_embed_data = mx.sym.Embedding(data=q_data, input_dim=self.n_question + 1,
                                    output_dim=self.q_embed_dim, name='q_embed')
    slice_q_embed_data = mx.sym.SliceChannel(q_embed_data, num_outputs=self.seqlen, axis=0, squeeze_axis=True)
    qa_data = mx.sym.BlockGrad(qa_data)
    qa_embed_data = mx.sym.Embedding(data=qa_data, input_dim=self.n_question * 2 + 1,
                                     output_dim=self.qa_embed_dim, name='qa_embed')
    slice_qa_embed_data = mx.sym.SliceChannel(qa_embed_data, num_outputs=self.seqlen, axis=0, squeeze_axis=True)

    value_read_content_l = []
    input_embed_l = []
    for i in range(self.seqlen):
        ## Attention
        q = slice_q_embed_data[i]
        correlation_weight = mem.attention(q)
        ## Read process
        read_content = mem.read(correlation_weight)  # Shape (batch_size, memory_state_dim)
        ### save intermediate data
        value_read_content_l.append(read_content)
        input_embed_l.append(q)
        ## Write process
        qa = slice_qa_embed_data[i]
        new_memory_value = mem.write(correlation_weight, qa)

    all_read_value_content = mx.sym.Concat(*value_read_content_l, num_args=self.seqlen, dim=0)
    input_embed_content = mx.sym.Concat(*input_embed_l, num_args=self.seqlen, dim=0)
    input_embed_content = mx.sym.FullyConnected(data=input_embed_content, num_hidden=50,
                                                name="input_embed_content")
    input_embed_content = mx.sym.Activation(data=input_embed_content, act_type='tanh',
                                            name="input_embed_content_tanh")

    read_content_embed = mx.sym.FullyConnected(data=mx.sym.Concat(all_read_value_content, input_embed_content,
                                                                  num_args=2, dim=1),
                                               num_hidden=self.final_fc_dim, name="read_content_embed")
    read_content_embed = mx.sym.Activation(data=read_content_embed, act_type='tanh',
                                           name="read_content_embed_tanh")

    pred = mx.sym.FullyConnected(data=read_content_embed, num_hidden=1, name="final_fc")
    pred_prob = logistic_regression_mask_output(data=mx.sym.Reshape(pred, shape=(-1,)),
                                                label=mx.sym.Reshape(data=target, shape=(-1,)),
                                                ignore_label=-1., name='final_pred')
    return mx.sym.Group([pred_prob, pred, read_content_embed, correlation_weight])
class MODEL(nn.Module):
    def __init__(self, n_question, batch_size, q_embed_dim, qa_embed_dim, memory_size,
                 memory_key_state_dim, memory_value_state_dim, final_fc_dim, student_num=None):
        super(MODEL, self).__init__()
        self.n_question = n_question
        self.batch_size = batch_size
        self.q_embed_dim = q_embed_dim
        self.qa_embed_dim = qa_embed_dim
        self.memory_size = memory_size
        self.memory_key_state_dim = memory_key_state_dim
        self.memory_value_state_dim = memory_value_state_dim
        self.final_fc_dim = final_fc_dim
        self.student_num = student_num

        self.input_embed_linear = nn.Linear(self.q_embed_dim, self.final_fc_dim, bias=True)
        self.read_embed_linear = nn.Linear(self.memory_value_state_dim + self.final_fc_dim,
                                           self.final_fc_dim, bias=True)
        self.predict_linear = nn.Linear(self.final_fc_dim, 1, bias=True)

        self.init_memory_key = nn.Parameter(torch.randn(self.memory_size, self.memory_key_state_dim))
        nn.init.kaiming_normal_(self.init_memory_key)
        self.init_memory_value = nn.Parameter(torch.randn(self.memory_size, self.memory_value_state_dim))
        nn.init.kaiming_normal_(self.init_memory_value)

        self.mem = DKVMN(self.memory_size, self.memory_key_state_dim, self.memory_value_state_dim,
                         self.init_memory_key)

        memory_value = nn.Parameter(
            torch.cat([self.init_memory_value.unsqueeze(0) for _ in range(batch_size)], 0).data)
        self.mem.set_memory_value(memory_value)

        self.q_embed = nn.Embedding(self.n_question + 1, self.q_embed_dim, padding_idx=0)
        self.qa_embed = nn.Embedding(self.n_question * 2 + 1, self.qa_embed_dim, padding_idx=0)

    def init_params(self):
        nn.init.kaiming_normal_(self.predict_linear.weight)
        nn.init.kaiming_normal_(self.read_embed_linear.weight)
        nn.init.constant_(self.read_embed_linear.bias, 0)
        nn.init.constant_(self.predict_linear.bias, 0)

    def init_embeddings(self):
        nn.init.kaiming_normal_(self.q_embed.weight)
        nn.init.kaiming_normal_(self.qa_embed.weight)

    def forward(self, q_data, qa_data, target, student_id=None):
        batch_size = q_data.shape[0]
        seqlen = q_data.shape[1]
        q_embed_data = self.q_embed(q_data)
        qa_embed_data = self.qa_embed(qa_data)

        memory_value = nn.Parameter(
            torch.cat([self.init_memory_value.unsqueeze(0) for _ in range(batch_size)], 0).data)
        self.mem.set_memory_value(memory_value)

        slice_q_data = torch.chunk(q_data, seqlen, 1)
        slice_q_embed_data = torch.chunk(q_embed_data, seqlen, 1)
        slice_qa_embed_data = torch.chunk(qa_embed_data, seqlen, 1)

        value_read_content_1 = []
        input_embed_1 = []
        predict_logs = []
        for i in range(seqlen):
            # attention
            q = slice_q_embed_data[i].squeeze(1)
            correlation_weight = self.mem.attention(q)
            if_memory_write = slice_q_data[i].squeeze(1).ge(1)
            if_memory_write = utils.variable(torch.FloatTensor(if_memory_write.data.tolist()), 1)
            # read
            read_content = self.mem.read(correlation_weight)
            value_read_content_1.append(read_content)
            input_embed_1.append(q)
            # write
            qa = slice_qa_embed_data[i].squeeze(1)
            new_memory_value = self.mem.write(correlation_weight, qa)

        # Stack the per-step tensors into [batch_size, seqlen, dim]
        all_read_value_content = torch.cat(
            [value_read_content_1[i].unsqueeze(1) for i in range(seqlen)], 1)
        input_embed_content = torch.cat(
            [input_embed_1[i].unsqueeze(1) for i in range(seqlen)], 1)

        # Note: read_embed_linear expects memory_value_state_dim + final_fc_dim inputs,
        # so this concatenation assumes q_embed_dim == final_fc_dim.
        predict_input = torch.cat([all_read_value_content, input_embed_content], 2)
        read_content_embed = torch.tanh(
            self.read_embed_linear(predict_input.reshape(batch_size * seqlen, -1)))
        pred = self.predict_linear(read_content_embed)

        target_1d = target
        mask = target_1d.ge(0)
        pred_1d = pred.reshape(-1, 1)
        filtered_pred = torch.masked_select(pred_1d, mask)
        filtered_target = torch.masked_select(target_1d, mask)
        loss = nn.functional.binary_cross_entropy_with_logits(filtered_pred, filtered_target)
        return loss, torch.sigmoid(filtered_pred), filtered_target
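# Unlike the first PyTorch variant (which masks on q_data.gt(0)), this variant masks on
# target.ge(0), so padded positions must carry a negative target. A hedged sketch of
# building such a target tensor from made-up answers:
import torch

n_question = 100
q = torch.tensor([[7, 42, 0, 0]])        # 0 is the padding question ID
a = torch.tensor([[1, 0, 0, 0]])         # answers; values at padded positions are ignored
target = a.float().masked_fill(q.eq(0), -1.0).view(-1, 1)   # [[1.], [0.], [-1.], [-1.]]
mask = target.ge(0)                      # True only for real interactions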
class Model(): def __init__(self, args, sess, name='KT'): self.args = args self.name = name self.sess = sess # Initialize Memory with tf.variable_scope('Memory'): init_memory_key = tf.get_variable('key', [self.args.memory_size, self.args.memory_key_state_dim], \ initializer=tf.truncated_normal_initializer(stddev=0.1)) init_memory_value = tf.get_variable('value', [self.args.memory_size,self.args.memory_value_state_dim], \ initializer=tf.truncated_normal_initializer(stddev=0.1)) # Broadcast memory value tensor to match [batch size, memory size, memory state dim] # First expand dim at axis 0 so that makes 'batch size' axis and tile it along 'batch size' axis # tf.tile(inputs, multiples) : multiples length must be thes saame as the number of dimensions in input # tf.stack takes a list and convert each element to a tensor init_memory_value = tf.tile(tf.expand_dims(init_memory_value, 0), tf.stack([self.args.batch_size, 1, 1])) print(init_memory_value.get_shape()) self.memory = DKVMN(self.args.memory_size, self.args.memory_key_state_dim, \ self.args.memory_value_state_dim, init_memory_key=init_memory_key, init_memory_value=init_memory_value, name='DKVMN') # Embedding to [batch size, seq_len, memory_state_dim(d_k or d_v)] with tf.variable_scope('Embedding'): # A self.q_embed_mtx = tf.get_variable('q_embed', [self.args.n_questions+1, self.args.memory_key_state_dim],\ initializer=tf.truncated_normal_initializer(stddev=0.1)) # B self.qa_embed_mtx = tf.get_variable('qa_embed', [2*self.args.n_questions+1, self.args.memory_value_state_dim], initializer=tf.truncated_normal_initializer(stddev=0.1)) self.prediction = self.build_network(reuse_flag=False) self.build_optimizer() #self.create_model() def build_network(self, reuse_flag): print('Building network') self.q_data = tf.placeholder(tf.int32, [self.args.batch_size], name='q_data') self.qa_data = tf.placeholder(tf.int32, [self.args.batch_size], name='qa_data') self.target = tf.placeholder(tf.float32, [self.args.batch_size], name='target') # Embedding to [batch size, seq_len, memory key state dim] q_embed_data = tf.nn.embedding_lookup(self.q_embed_mtx, self.q_data) # List of [batch size, 1, memory key state dim] with 'seq_len' elements #print('Q_embedding shape : %s' % q_embed_data.get_shape()) #slice_q_embed_data = tf.split(q_embed_data, self.args.seq_len, 1) #print(len(slice_q_embed_data), type(slice_q_embed_data), slice_q_embed_data[0].get_shape()) # Embedding to [batch size, seq_len, memory value state dim] qa_embed_data = tf.nn.embedding_lookup(self.qa_embed_mtx, self.qa_data) #print('QA_embedding shape: %s' % qa_embed_data.get_shape()) # List of [batch size, 1, memory value state dim] with 'seq_len' elements #slice_qa_embed_data = tf.split(qa_embed_data, self.args.seq_len, 1) #prediction = list() #reuse_flag = False # k_t : [batch size, memory key state dim] #q = tf.squeeze(slice_q_embed_data[i], 1) # Attention, [batch size, memory size] self.correlation_weight = self.memory.attention(q_embed_data) # Read process, [batch size, memory value state dim] self.read_content = self.memory.read(self.correlation_weight) # Write process, [batch size, memory size, memory value state dim] # qa : [batch size, memory value state dim] #qa = tf.squeeze(slice_qa_embed_data[i], 1) # Only last time step value is necessary self.new_memory_value = self.memory.write(self.correlation_weight, qa_embed_data, reuse=reuse_flag) mastery_level_prior_difficulty = tf.concat([self.read_content, q_embed_data], 1) # f_t summary_vector = 
tf.tanh(operations.linear(mastery_level_prior_difficulty, self.args.final_fc_dim, name='Summary_Vector', reuse=reuse_flag)) # p_t pred_logits = operations.linear(summary_vector, 1, name='Prediction', reuse=reuse_flag) return pred_logits def build_optimizer(self): print('Building optimizer') # 'seq_len' means question sequences self.q_data_seq = tf.placeholder(tf.int32, [self.args.batch_size, self.args.seq_len], name='q_data_seq') self.qa_data_seq = tf.placeholder(tf.int32, [self.args.batch_size, self.args.seq_len], name='qa_data_seq') self.target_seq = tf.placeholder(tf.float32, [self.args.batch_size, self.args.seq_len], name='target_seq') # Embedding to [batch size, seq_len, memory key state dim] #q_embed_data = tf.nn.embedding_lookup(q_embed_mtx, self.q_data_seq) # List of [batch size, 1, memory key state dim] with 'seq_len' elements #print('Q_embedding shape : %s' % q_embed_data.get_shape()) slice_q_data = tf.split(self.q_data_seq, self.args.seq_len, 1) #print(len(slice_q_embed_data), type(slice_q_embed_data), slice_q_embed_data[0].get_shape()) # Embedding to [batch size, seq_len, memory value state dim] #qa_embed_data = tf.nn.embedding_lookup(qa_embed_mtx, self.qa_data_seq) #print('QA_embedding shape: %s' % qa_embed_data.get_shape()) # List of [batch size, 1, memory value state dim] with 'seq_len' elements slice_qa_data = tf.split(self.qa_data_seq, self.args.seq_len, 1) prediction = list() reuse_flag = False # Logics for i in range(self.args.seq_len): # To reuse linear vectors q = tf.squeeze(slice_q_data[i], 1) # Attention, [batch size, memory size] qa = tf.squeeze(slice_qa_data[i], 1) # Only last time step value is necessary pred_logits = prediction.append(pred_logits) # 'prediction' : seq_len length list of [batch size ,1], make it [batch size, seq_len] tensor # tf.stack convert to [batch size, seq_len, 1] self.pred_logits = tf.reshape(tf.stack(prediction, axis=1), [self.args.batch_size, self.args.seq_len]) # Define loss : standard cross entropy loss, need to ignore '-1' label example # Make target/label 1-d array target_1d = tf.reshape(self.target_seq, [-1]) pred_logits_1d = tf.reshape(self.pred_logits, [-1]) index = tf.where(tf.not_equal(target_1d, tf.constant(-1., dtype=tf.float32))) # tf.gather(params, indices) : Gather slices from params according to indices filtered_target = tf.gather(target_1d, index) filtered_logits = tf.gather(pred_logits_1d, index) self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=filtered_logits, labels=filtered_target)) self.pred = tf.sigmoid(self.pred_logits) # Optimizer : SGD + MOMENTUM with learning rate decay self.global_step = tf.Variable(0, trainable=False) self.lr = tf.placeholder(tf.float32, [], name='learning_rate') # self.lr_decay = tf.train.exponential_decay(self.args.initial_lr, global_step=global_step, decay_steps=10000, decay_rate=0.667, staircase=True) # self.learning_rate = tf.maximum(lr, self.args.lr_lowerbound) optimizer = tf.train.MomentumOptimizer(self.lr, self.args.momentum) grads, vrbs = zip(*optimizer.compute_gradients(self.loss)) grad, _ = tf.clip_by_global_norm(grads, self.args.maxgradnorm) self.train_op = optimizer.apply_gradients(zip(grad, vrbs), global_step=self.global_step) # grad_clip = [(tf.clip_by_value(grad, -self.args.maxgradnorm, self.args.maxgradnorm), var) for grad, var in grads] self.tr_vrbs = tf.trainable_variables() for i in self.tr_vrbs: print(i.name) self.saver = tf.train.Saver() def create_model(self): # 'seq_len' means question sequences self.q_data_seq = tf.placeholder(tf.int32, 
[self.args.batch_size, self.args.seq_len], name='q_data_seq') self.qa_data_seq = tf.placeholder(tf.int32, [self.args.batch_size, self.args.seq_len], name='qa_data') self.target_seq = tf.placeholder(tf.float32, [self.args.batch_size, self.args.seq_len], name='target') ''' # Initialize Memory with tf.variable_scope('Memory'): init_memory_key = tf.get_variable('key', [self.args.memory_size, self.args.memory_key_state_dim], \ initializer=tf.truncated_normal_initializer(stddev=0.1)) init_memory_value = tf.get_variable('value', [self.args.memory_size,self.args.memory_value_state_dim], \ initializer=tf.truncated_normal_initializer(stddev=0.1)) # Broadcast memory value tensor to match [batch size, memory size, memory state dim] # First expand dim at axis 0 so that makes 'batch size' axis and tile it along 'batch size' axis # tf.tile(inputs, multiples) : multiples length must be thes saame as the number of dimensions in input # tf.stack takes a list and convert each element to a tensor init_memory_value = tf.tile(tf.expand_dims(init_memory_value, 0), tf.stack([self.args.batch_size, 1, 1])) print(init_memory_value.get_shape()) self.memory = DKVMN(self.args.memory_size, self.args.memory_key_state_dim, \ self.args.memory_value_state_dim, init_memory_key=init_memory_key, init_memory_value=init_memory_value, name='DKVMN') # Embedding to [batch size, seq_len, memory_state_dim(d_k or d_v)] with tf.variable_scope('Embedding'): # A q_embed_mtx = tf.get_variable('q_embed', [self.args.n_questions+1, self.args.memory_key_state_dim],\ initializer=tf.truncated_normal_initializer(stddev=0.1)) # B qa_embed_mtx = tf.get_variable('qa_embed', [2*self.args.n_questions+1, self.args.memory_value_state_dim], initializer=tf.truncated_normal_initializer(stddev=0.1)) ''' # Embedding to [batch size, seq_len, memory key state dim] q_embed_data = tf.nn.embedding_lookup(self.q_embed_mtx, self.q_data_seq) # List of [batch size, 1, memory key state dim] with 'seq_len' elements #print('Q_embedding shape : %s' % q_embed_data.get_shape()) slice_q_embed_data = tf.split(q_embed_data, self.args.seq_len, 1) #print(len(slice_q_embed_data), type(slice_q_embed_data), slice_q_embed_data[0].get_shape()) # Embedding to [batch size, seq_len, memory value state dim] qa_embed_data = tf.nn.embedding_lookup(self.qa_embed_mtx, self.qa_data_seq) #print('QA_embedding shape: %s' % qa_embed_data.get_shape()) # List of [batch size, 1, memory value state dim] with 'seq_len' elements slice_qa_embed_data = tf.split(qa_embed_data, self.args.seq_len, 1) prediction = list() reuse_flag = False # Logics for i in range(self.args.seq_len): # To reuse linear vectors if i != 0: reuse_flag = True # k_t : [batch size, memory key state dim] q = tf.squeeze(slice_q_embed_data[i], 1) # Attention, [batch size, memory size] self.correlation_weight = self.memory.attention(q) # Read process, [batch size, memory value state dim] self.read_content = self.memory.read(self.correlation_weight) # Write process, [batch size, memory size, memory value state dim] # qa : [batch size, memory value state dim] qa = tf.squeeze(slice_qa_embed_data[i], 1) # Only last time step value is necessary self.new_memory_value = self.memory.write(self.correlation_weight, qa, reuse=reuse_flag) mastery_level_prior_difficulty = tf.concat([self.read_content, q], 1) # f_t summary_vector = tf.tanh(operations.linear(mastery_level_prior_difficulty, self.args.final_fc_dim, name='Summary_Vector', reuse=reuse_flag)) # p_t pred_logits = operations.linear(summary_vector, 1, name='Prediction', reuse=reuse_flag) 
            prediction.append(pred_logits)

        # 'prediction' : a seq_len-length list of [batch_size, 1] tensors; make it a [batch_size, seq_len] tensor
        # tf.stack converts it to [batch_size, seq_len, 1]
        self.pred_logits = tf.reshape(tf.stack(prediction, axis=1), [self.args.batch_size, self.args.seq_len])

        # Define loss : standard cross-entropy loss, ignoring the '-1' padding labels
        # Make target/label a 1-d array
        target_1d = tf.reshape(self.target_seq, [-1])
        pred_logits_1d = tf.reshape(self.pred_logits, [-1])
        index = tf.where(tf.not_equal(target_1d, tf.constant(-1., dtype=tf.float32)))
        # tf.gather(params, indices) : gather slices from params according to indices
        filtered_target = tf.gather(target_1d, index)
        filtered_logits = tf.gather(pred_logits_1d, index)
        self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=filtered_logits, labels=filtered_target))
        self.pred = tf.sigmoid(self.pred_logits)

        # Optimizer : SGD + momentum with gradient clipping
        self.global_step = tf.Variable(0, trainable=False)
        self.lr = tf.placeholder(tf.float32, [], name='learning_rate')
        # self.lr_decay = tf.train.exponential_decay(self.args.initial_lr, global_step=global_step, decay_steps=10000, decay_rate=0.667, staircase=True)
        # self.learning_rate = tf.maximum(lr, self.args.lr_lowerbound)
        optimizer = tf.train.MomentumOptimizer(self.lr, self.args.momentum)
        grads, vrbs = zip(*optimizer.compute_gradients(self.loss))
        grad, _ = tf.clip_by_global_norm(grads, self.args.maxgradnorm)
        self.train_op = optimizer.apply_gradients(zip(grad, vrbs), global_step=self.global_step)
        # grad_clip = [(tf.clip_by_value(grad, -self.args.maxgradnorm, self.args.maxgradnorm), var) for grad, var in grads]

        self.tr_vrbs = tf.trainable_variables()
        for i in self.tr_vrbs:
            print(i.name)

        self.saver = tf.train.Saver()

    def train(self, train_q_data, train_qa_data, valid_q_data, valid_qa_data):
        # q_data, qa_data : [samples, seq_len]
        shuffle_index = np.random.permutation(train_q_data.shape[0])
        q_data_shuffled = train_q_data[shuffle_index]
        qa_data_shuffled = train_qa_data[shuffle_index]
        training_step = train_q_data.shape[0] // self.args.batch_size
        self.sess.run(tf.global_variables_initializer())

        if self.args.show:
            from utils import ProgressBar
            bar = ProgressBar('Training', max=training_step)

        self.train_count = 0
        if self.args.init_from:
            if self.load():
                print('Checkpoint loaded')
            else:
                print('No checkpoint')
        else:
            if os.path.exists(os.path.join(self.args.checkpoint_dir, self.model_dir)):
                try:
                    shutil.rmtree(os.path.join(self.args.checkpoint_dir, self.model_dir))
                    shutil.rmtree(os.path.join(self.args.log_dir, self.model_dir + '.csv'))
                except (FileNotFoundError, IOError) as e:
                    print('[Delete Error] %s - %s' % (e.filename, e.strerror))

        best_valid_auc = 0

        # Training
        for epoch in range(0, self.args.num_epochs):
            if self.args.show:
                bar.next()

            pred_list = list()
            target_list = list()
            epoch_loss = 0
            learning_rate = tf.train.exponential_decay(self.args.initial_lr, global_step=self.global_step,
                                                       decay_steps=self.args.anneal_interval * training_step,
                                                       decay_rate=0.667, staircase=True)
            # print('Epoch %d starts with learning rate : %3.5f' % (epoch+1, self.sess.run(learning_rate)))

            for steps in range(training_step):
                # [batch_size, seq_len]
                q_batch_seq = q_data_shuffled[steps * self.args.batch_size:(steps + 1) * self.args.batch_size, :]
                qa_batch_seq = qa_data_shuffled[steps * self.args.batch_size:(steps + 1) * self.args.batch_size, :]
                # qa : exercise index + answer(0 or 1) * n_questions
                # right : 1, wrong : 0, padding : -1
                target = qa_batch_seq[:, :]
                # Cast to integer to compute the target
                target = target.astype(np.int)
                target_batch = (target - 1) // self.args.n_questions
                target_batch = target_batch.astype(np.float)

                feed_dict = {self.q_data_seq: q_batch_seq, self.qa_data_seq: qa_batch_seq,
                             self.target_seq: target_batch, self.lr: self.args.initial_lr}
                # self.lr: self.sess.run(learning_rate)
                loss_, pred_, _ = self.sess.run([self.loss, self.pred, self.train_op], feed_dict=feed_dict)

                # Keep only the entries with a real answer
                # Make [batch_size * seq_len, 1]
                right_target = np.asarray(target_batch).reshape(-1, 1)
                right_pred = np.asarray(pred_).reshape(-1, 1)
                # np.flatnonzero returns the indices of the non-padding entries; convert them to a list
                right_index = np.flatnonzero(right_target != -1.).tolist()
                # 'training_step'-length lists of per-batch arrays
                pred_list.append(right_pred[right_index])
                target_list.append(right_target[right_index])
                epoch_loss += loss_
                # print('Epoch %d/%d, steps %d/%d, loss : %3.5f' % (epoch+1, self.args.num_epochs, steps+1, training_step, loss_))

            if self.args.show:
                bar.finish()

            all_pred = np.concatenate(pred_list, axis=0)
            all_target = np.concatenate(target_list, axis=0)

            # Compute metrics : AUC on the raw probabilities, then threshold at 0.5 for accuracy
            self.auc = metrics.roc_auc_score(all_target, all_pred)
            all_pred[all_pred > 0.5] = 1
            all_pred[all_pred <= 0.5] = 0
            self.accuracy = metrics.accuracy_score(all_target, all_pred)

            epoch_loss = epoch_loss / training_step
            print('Epoch %d/%d, loss : %3.5f, auc : %3.5f, accuracy : %3.5f' %
                  (epoch + 1, self.args.num_epochs, epoch_loss, self.auc, self.accuracy))
            self.write_log(epoch=epoch + 1, auc=self.auc, accuracy=self.accuracy, loss=epoch_loss, name='training_')

            # Validation
            valid_steps = valid_q_data.shape[0] // self.args.batch_size
            valid_pred_list = list()
            valid_target_list = list()
            for s in range(valid_steps):
                valid_q = valid_q_data[s * self.args.batch_size:(s + 1) * self.args.batch_size, :]
                valid_qa = valid_qa_data[s * self.args.batch_size:(s + 1) * self.args.batch_size, :]
                # right : 1, wrong : 0, padding : -1
                valid_target = (valid_qa - 1) // self.args.n_questions
                valid_feed_dict = {self.q_data_seq: valid_q, self.qa_data_seq: valid_qa, self.target_seq: valid_target}
                valid_loss, valid_pred = self.sess.run([self.loss, self.pred], feed_dict=valid_feed_dict)
                # Same masking as for the training set
                valid_right_target = np.asarray(valid_target).reshape(-1,)
                valid_right_pred = np.asarray(valid_pred).reshape(-1,)
                valid_right_index = np.flatnonzero(valid_right_target != -1).tolist()
                valid_target_list.append(valid_right_target[valid_right_index])
                valid_pred_list.append(valid_right_pred[valid_right_index])

            all_valid_pred = np.concatenate(valid_pred_list, axis=0)
            all_valid_target = np.concatenate(valid_target_list, axis=0)
            valid_auc = metrics.roc_auc_score(all_valid_target, all_valid_pred)
            # For validation accuracy
            all_valid_pred[all_valid_pred > 0.5] = 1
            all_valid_pred[all_valid_pred <= 0.5] = 0
            valid_accuracy = metrics.accuracy_score(all_valid_target, all_valid_pred)
            print('Epoch %d/%d, valid auc : %3.5f, valid accuracy : %3.5f' %
                  (epoch + 1, self.args.num_epochs, valid_auc, valid_accuracy))
            # Valid log
            self.write_log(epoch=epoch + 1, auc=valid_auc, accuracy=valid_accuracy, loss=valid_loss, name='valid_')

            if valid_auc > best_valid_auc:
                print('Valid auc improved from %3.4f to %3.4f' % (best_valid_auc, valid_auc))
                best_valid_auc = valid_auc
                best_epoch = epoch + 1
                self.save(best_epoch)

        return best_epoch

    def test(self, test_q, test_qa):
        steps = test_q.shape[0] // self.args.batch_size
        self.sess.run(tf.global_variables_initializer())
        if self.load():
            print('CKPT Loaded')
        else:
            raise Exception('CKPT needed')

        pred_list = list()
        target_list = list()
        for s in range(steps):
            test_q_batch = test_q[s * self.args.batch_size:(s + 1) * self.args.batch_size, :]
            test_qa_batch = test_qa[s * self.args.batch_size:(s + 1) * self.args.batch_size, :]
            target = test_qa_batch[:, :]
            target = target.astype(np.int)
            target_batch = (target - 1) // self.args.n_questions
            target_batch = target_batch.astype(np.float)
            feed_dict = {self.q_data_seq: test_q_batch, self.qa_data_seq: test_qa_batch, self.target_seq: target_batch}
            loss_, pred_ = self.sess.run([self.loss, self.pred], feed_dict=feed_dict)
            # Keep only the entries with a real answer
            # Make [batch_size * seq_len, 1]
            right_target = np.asarray(target_batch).reshape(-1, 1)
            right_pred = np.asarray(pred_).reshape(-1, 1)
            # np.flatnonzero returns the indices of the non-padding entries; convert them to a list
            right_index = np.flatnonzero(right_target != -1.).tolist()
            pred_list.append(right_pred[right_index])
            target_list.append(right_target[right_index])

        all_pred = np.concatenate(pred_list, axis=0)
        all_target = np.concatenate(target_list, axis=0)

        # Compute metrics : AUC on the raw probabilities, then threshold at 0.5 for accuracy
        self.test_auc = metrics.roc_auc_score(all_target, all_pred)
        all_pred[all_pred > 0.5] = 1
        all_pred[all_pred <= 0.5] = 0
        self.test_accuracy = metrics.accuracy_score(all_target, all_pred)
        print('Test auc : %3.4f, Test accuracy : %3.4f' % (self.test_auc, self.test_accuracy))

    @property
    def model_dir(self):
        return '{}_{}batch_{}epochs'.format(self.args.dataset, self.args.batch_size, self.args.num_epochs)

    def load(self):
        checkpoint_dir = os.path.join(self.args.checkpoint_dir, self.model_dir)
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.train_count = int(ckpt_name.split('-')[-1])
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            return True
        else:
            return False

    def save(self, global_step):
        model_name = 'DKVMN'
        checkpoint_dir = os.path.join(self.args.checkpoint_dir, self.model_dir)
        if not os.path.exists(checkpoint_dir):
            os.mkdir(checkpoint_dir)
        self.saver.save(self.sess, os.path.join(checkpoint_dir, model_name), global_step=global_step)
        print('Saved checkpoint at %d' % global_step)

    # Log file
    def write_log(self, auc, accuracy, loss, epoch, name='training_'):
        log_path = os.path.join(self.args.log_dir, name + self.model_dir + '.csv')
        if not os.path.exists(log_path):
            self.log_file = open(log_path, 'w')
            self.log_file.write('Epoch\tAuc\tAccuracy\tLoss\n')
        else:
            self.log_file = open(log_path, 'a')
        self.log_file.write(str(epoch) + '\t' + str(auc) + '\t' + str(accuracy) + '\t' + str(loss) + '\n')
        self.log_file.flush()
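# A minimal sketch (not part of the model above; values are illustrative) of the label derivation
# used in train()/test(): qa encodes question_id + answer * n_questions with question ids starting
# at 1, so (qa - 1) // n_questions gives 0 for a wrong answer, 1 for a right answer, and -1 for the
# 0-padding, which is then dropped via np.flatnonzero before computing AUC/accuracy.
import numpy as np

n_questions = 5
qa = np.array([[3, 8, 1, 0, 0]])              # q3 wrong, q3 right, q1 wrong, then padding
target = (qa.astype(np.int64) - 1) // n_questions
print(target)                                 # [[ 0  1  0 -1 -1]] : padding becomes -1
kept = np.flatnonzero(target.reshape(-1, 1) != -1).tolist()
print(target.reshape(-1, 1)[kept].ravel())    # [0 1 0] : only real interactions are scored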
def sym_gen(self):
    ### TODO input variable 'q_data'
    q_data = mx.sym.Variable('q_data', shape=(self.seqlen, self.batch_size))    # (seqlen, batch_size)
    ### TODO input variable 'qa_data'
    qa_data = mx.sym.Variable('qa_data', shape=(self.seqlen, self.batch_size))  # (seqlen, batch_size)
    ### TODO input variable 'target'
    target = mx.sym.Variable('target', shape=(self.seqlen, self.batch_size))    # (seqlen, batch_size)

    ### Initialize Memory
    init_memory_key = mx.sym.Variable('init_memory_key_weight')
    init_memory_value = mx.sym.Variable('init_memory_value',
                                        shape=(self.memory_size, self.memory_value_state_dim),
                                        init=mx.init.Normal(0.1))  # (memory_size, memory_value_state_dim)
    init_memory_value = mx.sym.broadcast_to(mx.sym.expand_dims(init_memory_value, axis=0),
                                            shape=(self.batch_size, self.memory_size, self.memory_value_state_dim))
    mem = DKVMN(memory_size=self.memory_size,
                memory_key_state_dim=self.memory_key_state_dim,
                memory_value_state_dim=self.memory_value_state_dim,
                init_memory_key=init_memory_key,
                init_memory_value=init_memory_value,
                name="DKVMN")

    ### embedding
    q_data = mx.sym.BlockGrad(q_data)
    q_embed_data = mx.sym.Embedding(data=q_data, input_dim=self.n_question + 1,
                                    output_dim=self.q_embed_dim, name='q_embed')
    slice_q_embed_data = mx.sym.SliceChannel(q_embed_data, num_outputs=self.seqlen, axis=0, squeeze_axis=True)
    qa_data = mx.sym.BlockGrad(qa_data)
    qa_embed_data = mx.sym.Embedding(data=qa_data, input_dim=self.n_question * 2 + 1,
                                     output_dim=self.qa_embed_dim, name='qa_embed')
    slice_qa_embed_data = mx.sym.SliceChannel(qa_embed_data, num_outputs=self.seqlen, axis=0, squeeze_axis=True)

    # memory block
    value_read_content_l = []  # seqlen items of (batch_size, memory_state_dim)
    input_embed_l = []         # seqlen items of (batch_size, q_embed_dim)
    for i in range(self.seqlen):
        ## Attention
        q = slice_q_embed_data[i]
        correlation_weight = mem.attention(q)
        ## Read Process
        read_content = mem.read(correlation_weight)  # Shape (batch_size, memory_state_dim)
        ### save intermediate data
        value_read_content_l.append(read_content)
        input_embed_l.append(q)
        ## Write Process
        qa = slice_qa_embed_data[i]
        new_memory_value = mem.write(correlation_weight, qa)

    # (seqlen * batch_size, memory_state_dim) -> (seqlen, batch_size, memory_state_dim)
    all_read_value_content = mx.sym.Concat(*value_read_content_l, num_args=self.seqlen, dim=0)
    all_read_value_content = mx.sym.reshape(data=all_read_value_content, shape=(self.seqlen, self.batch_size, -1))
    # (seqlen * batch_size, q_embed_dim) -> (seqlen, batch_size, q_embed_dim)
    input_embed_content = mx.sym.Concat(*input_embed_l, num_args=self.seqlen, dim=0)
    input_embed_content = mx.sym.reshape(data=input_embed_content, shape=(self.seqlen, self.batch_size, -1))

    combined_data = mx.sym.Concat(all_read_value_content, input_embed_content, num_args=2, dim=2)
    lstm = mx.gluon.rnn.LSTM(50, num_layers=1, layout='TNC', prefix='LSTM')
    output = lstm(combined_data)                          # (seqlen, batch_size, 50)
    output = mx.sym.reshape(data=output, shape=(-3, 0))   # merge the seqlen and batch_size axes
    # output = mx.sym.FullyConnected(data=output, num_hidden=self.final_fc_dim, name='fc')
    # output = mx.sym.Activation(data=output, act_type='tanh', name='fc_tanh')
    pred = mx.sym.FullyConnected(data=output, num_hidden=1, name="final_fc")
    pred_prob = logistic_regression_mask_output(data=mx.sym.Reshape(pred, shape=(-1,)),
                                                label=mx.sym.Reshape(data=target, shape=(-1,)),
                                                ignore_label=-1., name='final_pred')
    net = mx.sym.Group([pred_prob])
    return net
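# A small sketch of the reshape(shape=(-3, 0)) applied to the LSTM output above: in MXNet's
# reshape, -3 merges two consecutive input dimensions and 0 copies a dimension unchanged, so a
# (seqlen, batch_size, hidden) tensor collapses to (seqlen * batch_size, hidden) before the final
# fully-connected layer. Shapes here are illustrative only.
import mxnet as mx

seqlen, batch_size, hidden = 4, 2, 3
x = mx.nd.arange(seqlen * batch_size * hidden).reshape((seqlen, batch_size, hidden))
y = mx.nd.reshape(x, shape=(-3, 0))
print(x.shape, '->', y.shape)  # (4, 2, 3) -> (8, 3)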
class MODEL(nn.Module):
    def __init__(self, n_question, batch_size, q_embed_dim, qa_embed_dim,
                 memory_size, memory_key_state_dim, memory_value_state_dim,
                 final_fc_dim, student_num=None):
        super(MODEL, self).__init__()
        self.n_question = n_question
        self.batch_size = batch_size
        self.q_embed_dim = q_embed_dim
        self.qa_embed_dim = qa_embed_dim
        self.memory_size = memory_size
        self.memory_key_state_dim = memory_key_state_dim
        self.memory_value_state_dim = memory_value_state_dim
        self.final_fc_dim = final_fc_dim
        self.student_num = student_num

        self.input_embed_linear = nn.Linear(self.q_embed_dim, self.final_fc_dim, bias=True)
        self.read_embed_linear = nn.Linear(self.memory_value_state_dim + self.final_fc_dim,
                                           self.final_fc_dim, bias=True)
        self.predict_linear = nn.Linear(self.final_fc_dim, 1, bias=True)

        self.init_memory_key = nn.Parameter(
            torch.randn(self.memory_size, self.memory_key_state_dim))  # random
        nn.init.kaiming_normal_(self.init_memory_key)
        self.init_memory_value = nn.Parameter(
            torch.randn(self.memory_size, self.memory_value_state_dim))
        # nn.init.kaiming_normal_(self.init_memory_value)

        self.mem = DKVMN(
            memory_size=self.memory_size,
            memory_key_state_dim=self.memory_key_state_dim,  # the key memory stays fixed after initialization
            memory_value_state_dim=self.memory_value_state_dim,
            init_memory_key=self.init_memory_key)
        # the value memory is updated repeatedly by write
        memory_value = nn.Parameter(
            torch.cat([
                self.init_memory_value.unsqueeze(0) for _ in range(batch_size)
            ], 0).data)
        self.mem.init_value_memory(memory_value)

        self.q_embed = nn.Embedding(self.n_question + 1, self.q_embed_dim, padding_idx=0)
        self.qa_embed = nn.Embedding(2 * self.n_question + 1, self.qa_embed_dim, padding_idx=0)

    def init_params(self):
        nn.init.kaiming_normal_(self.predict_linear.weight)
        nn.init.kaiming_normal_(self.read_embed_linear.weight)
        nn.init.constant_(self.read_embed_linear.bias, 0)
        nn.init.constant_(self.predict_linear.bias, 0)
        # nn.init.constant_(self.input_embed_linear.bias, 0)
        # nn.init.normal(self.input_embed_linear.weight, std=0.02)

    def init_embeddings(self):
        nn.init.kaiming_normal_(self.q_embed.weight)
        nn.init.kaiming_normal_(self.qa_embed.weight)

    def forward(self, q_data, qa_data, target, student_id=None):
        batch_size = q_data.shape[0]
        seqlen = q_data.shape[1]

        q_embed_data = self.q_embed(q_data)
        qa_embed_data = self.qa_embed(qa_data)

        memory_value = nn.Parameter(
            torch.cat([
                self.init_memory_value.unsqueeze(0) for _ in range(batch_size)
            ], 0).data)  # memory: 32, 20, 200
        self.mem.init_value_memory(memory_value)

        slice_q_data = torch.chunk(q_data, seqlen, 1)
        slice_q_embed_data = torch.chunk(q_embed_data, seqlen, 1)
        slice_qa_embed_data = torch.chunk(qa_embed_data, seqlen, 1)

        value_read_content_l = []
        input_embed_l = []
        # new_memory_value_l = []
        predict_logs = []
        for i in range(seqlen):
            ## Attention
            q = slice_q_embed_data[i].squeeze(1)
            correlation_weight = self.mem.attention(q)
            # write only where the question index is non-padding (>= 1)
            if_memory_write = slice_q_data[i].squeeze(1).ge(1)
            if_memory_write = utils.varible(
                torch.FloatTensor(if_memory_write.data.tolist()), 1)
            ## Read Process
            read_content = self.mem.read(correlation_weight)
            value_read_content_l.append(read_content)
            input_embed_l.append(q)
            ## Write Process
            qa = slice_qa_embed_data[i].squeeze(1)
            new_memory_value = self.mem.write(correlation_weight, qa, if_memory_write)
            # the memory_value left after the last of the seqlen (e.g. 200) updates is returned
            # directly as the knowledge-state vector
            # new_memory_value_l.append(new_memory_value.squ)
            # read_content_embed = torch.tanh(self.read_embed_linear(torch.cat([read_content, q], 1)))
            # pred = self.predict_linear(read_content_embed)
            # predict_logs.append(pred)

        # concatenate the read content of every exercise along the sequence axis
        all_read_value_content = torch.cat(
            [value_read_content_l[i].unsqueeze(1) for i in range(seqlen)], 1)
        input_embed_content = torch.cat(
            [input_embed_l[i].unsqueeze(1) for i in range(seqlen)], 1)
        # input_embed_content = input_embed_content.view(batch_size * seqlen, -1)
        # input_embed_content = torch.tanh(self.input_embed_linear(input_embed_content))
        # input_embed_content = input_embed_content.view(batch_size, seqlen, -1)

        # note: with the input projection above commented out, this concat expects q_embed_dim == final_fc_dim
        predict_input = torch.cat([all_read_value_content, input_embed_content], 2)
        read_content_embed = torch.tanh(
            self.read_embed_linear(predict_input.view(batch_size * seqlen, -1)))
        pred = self.predict_linear(read_content_embed)
        # predicts = torch.cat([predict_logs[i] for i in range(seqlen)], 1)

        target_1d = target               # [batch_size * seq_len, 1]
        mask = target_1d.ge(0)           # [batch_size * seq_len, 1]
        # pred_1d = predicts.view(-1, 1) # [batch_size * seq_len, 1]
        pred_1d = pred.view(-1, 1)       # [batch_size * seq_len, 1]

        filtered_pred = torch.masked_select(pred_1d, mask)
        filtered_target = torch.masked_select(target_1d, mask)
        # memory_value = torch.masked_select(new_memory_value.view(batch_size * seqlen, -1), mask)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(filtered_pred, filtered_target)

        return loss, torch.sigmoid(filtered_pred), filtered_target, new_memory_value
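# A minimal sketch, assuming the DKVMN write op treats if_memory_write as a per-sample gate (the
# actual mem.write implementation is not shown here): padded time steps (question index 0) keep
# the old value memory, while real interactions take the newly written one. All names, shapes and
# values below are illustrative.
import torch

batch_size, memory_size, value_dim = 2, 20, 200
old_value = torch.randn(batch_size, memory_size, value_dim)
written_value = torch.randn(batch_size, memory_size, value_dim)  # result of the erase/add update
q_step = torch.tensor([7, 0])                                    # second sample is padding
gate = q_step.ge(1).float().view(-1, 1, 1)                       # (batch_size, 1, 1)
new_value = gate * written_value + (1.0 - gate) * old_value      # padded rows stay unchanged
print(torch.equal(new_value[1], old_value[1]))                   # True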
class DeepIRTModel(object):
    def __init__(self, args, sess, name="KT"):
        self.args = args
        self.sess = sess
        self.name = name
        self.create_model()

    def create_model(self):
        self._create_placeholder()
        self._influence()
        self._create_loss()
        self._create_optimizer()
        self._add_summary()

    def _create_placeholder(self):
        logger.info("Initializing Placeholder")
        self.q_data = tf.placeholder(tf.int32, [self.args.batch_size, self.args.seq_len], name='q_data')
        self.qa_data = tf.placeholder(tf.int32, [self.args.batch_size, self.args.seq_len], name='qa_data')
        self.label = tf.placeholder(tf.float32, [self.args.batch_size, self.args.seq_len], name='label')

    def _influence(self):
        # Initialize Memory
        logger.info("Initializing Key and Value Memory")
        with tf.variable_scope("Memory"):
            init_key_memory = tf.get_variable(
                'key_memory_matrix',
                [self.args.memory_size, self.args.key_memory_state_dim],
                initializer=tf.truncated_normal_initializer(stddev=0.1))
            init_value_memory = tf.get_variable(
                'value_memory_matrix',
                [self.args.memory_size, self.args.value_memory_state_dim],
                initializer=tf.truncated_normal_initializer(stddev=0.1))

        # Broadcast the value-memory matrix to shape (batch_size, memory_size, value_memory_state_dim)
        init_value_memory = tf.tile(               # tile the value memory along the batch axis
            tf.expand_dims(init_value_memory, 0),  # add the batch axis
            tf.stack([self.args.batch_size, 1, 1]))
        logger.debug("Shape of init_value_memory = {}".format(init_value_memory.get_shape()))
        logger.debug("Shape of init_key_memory = {}".format(init_key_memory.get_shape()))

        # Initialize DKVMN
        self.memory = DKVMN(
            memory_size=self.args.memory_size,
            key_memory_state_dim=self.args.key_memory_state_dim,
            value_memory_state_dim=self.args.value_memory_state_dim,
            init_key_memory=init_key_memory,
            init_value_memory=init_value_memory,
            name="DKVMN")

        # Initialize Embedding
        logger.info("Initializing Q and QA Embedding")
        with tf.variable_scope('Embedding'):
            q_embed_matrix = tf.get_variable(
                'q_embed',
                [self.args.n_questions + 1, self.args.key_memory_state_dim],
                initializer=tf.truncated_normal_initializer(stddev=0.1))
            qa_embed_matrix = tf.get_variable(
                'qa_embed',
                [2 * self.args.n_questions + 1, self.args.value_memory_state_dim],
                initializer=tf.truncated_normal_initializer(stddev=0.1))

        # Embedding to shape (batch_size, seq_len, memory_state_dim (d_k or d_v))
        logger.info("Initializing Embedding Lookup")
        q_embed_data = tf.nn.embedding_lookup(q_embed_matrix, self.q_data)
        qa_embed_data = tf.nn.embedding_lookup(qa_embed_matrix, self.qa_data)
        logger.debug("Shape of q_embed_data: {}".format(q_embed_data.get_shape()))
        logger.debug("Shape of qa_embed_data: {}".format(qa_embed_data.get_shape()))

        sliced_q_embed_data = tf.split(value=q_embed_data, num_or_size_splits=self.args.seq_len, axis=1)
        sliced_qa_embed_data = tf.split(value=qa_embed_data, num_or_size_splits=self.args.seq_len, axis=1)
        logger.debug("Shape of sliced_q_embed_data[0]: {}".format(sliced_q_embed_data[0].get_shape()))
        logger.debug("Shape of sliced_qa_embed_data[0]: {}".format(sliced_qa_embed_data[0].get_shape()))

        pred_z_values = list()
        student_abilities = list()
        question_difficulties = list()
        reuse_flag = False

        logger.info("Initializing Influence Procedure")
        for i in range(self.args.seq_len):
            # Reuse the shared linear layers after the first time step
            if i != 0:
                reuse_flag = True

            # Get the query and content vector
            q = tf.squeeze(sliced_q_embed_data[i], 1)
            qa = tf.squeeze(sliced_qa_embed_data[i], 1)
            logger.debug("query vector q: {}".format(q))
            logger.debug("content vector qa: {}".format(qa))

            # Attention, correlation_weight: Shape (batch_size, memory_size)
            self.correlation_weight = self.memory.attention(embedded_query_vector=q)
            logger.debug("correlation_weight: {}".format(self.correlation_weight))

            # Read process, read_content: Shape (batch_size, value_memory_state_dim)
            self.read_content = self.memory.read(correlation_weight=self.correlation_weight)
            logger.debug("read_content: {}".format(self.read_content))

            # Write process, new_memory_value: Shape (batch_size, memory_size, value_memory_state_dim)
            self.new_memory_value = self.memory.write(self.correlation_weight, qa, reuse=reuse_flag)
            logger.debug("new_memory_value: {}".format(self.new_memory_value))

            # Build the feature vector -- summary_vector
            mastery_level_prior_difficulty = tf.concat([self.read_content, q], 1)
            self.summary_vector = layers.fully_connected(
                inputs=mastery_level_prior_difficulty,
                num_outputs=self.args.summary_vector_output_dim,
                scope='SummaryOperation',
                reuse=reuse_flag,
                activation_fn=tf.nn.tanh)
            logger.debug("summary_vector: {}".format(self.summary_vector))

            # Calculate the student ability level from the summary vector
            student_ability = layers.fully_connected(
                inputs=self.summary_vector,
                num_outputs=1,
                scope='StudentAbilityOutputLayer',
                reuse=reuse_flag,
                activation_fn=None)

            # Calculate the question difficulty level from the question embedding
            question_difficulty = layers.fully_connected(
                inputs=q,
                num_outputs=1,
                scope='QuestionDifficultyOutputLayer',
                reuse=reuse_flag,
                activation_fn=tf.nn.tanh)

            # Prediction : item-response-style logit
            pred_z_value = 3.0 * student_ability - question_difficulty
            pred_z_values.append(pred_z_value)
            student_abilities.append(student_ability)
            question_difficulties.append(question_difficulty)

        self.pred_z_values = tf.reshape(tf.stack(pred_z_values, axis=1),
                                        [self.args.batch_size, self.args.seq_len])
        self.student_abilities = tf.reshape(tf.stack(student_abilities, axis=1),
                                            [self.args.batch_size, self.args.seq_len])
        self.question_difficulties = tf.reshape(tf.stack(question_difficulties, axis=1),
                                                [self.args.batch_size, self.args.seq_len])
        logger.debug("Shape of pred_z_values: {}".format(self.pred_z_values))
        logger.debug("Shape of student_abilities: {}".format(self.student_abilities))
        logger.debug("Shape of question_difficulties: {}".format(self.question_difficulties))

    def _create_loss(self):
        logger.info("Initializing Loss Function")
        # convert into 1-D
        label_1d = tf.reshape(self.label, [-1])
        pred_z_values_1d = tf.reshape(self.pred_z_values, [-1])
        student_abilities_1d = tf.reshape(self.student_abilities, [-1])
        question_difficulties_1d = tf.reshape(self.question_difficulties, [-1])

        # find the indices of labels that are not masked (-1)
        index = tf.where(tf.not_equal(label_1d, tf.constant(-1., dtype=tf.float32)))

        # masking
        filtered_label = tf.gather(label_1d, index)
        filtered_z_values = tf.gather(pred_z_values_1d, index)
        filtered_student_abilities = tf.gather(student_abilities_1d, index)
        filtered_question_difficulties = tf.gather(question_difficulties_1d, index)
        logger.debug("Shape of filtered_label: {}".format(filtered_label))
        logger.debug("Shape of filtered_z_values: {}".format(filtered_z_values))
        logger.debug("Shape of filtered_student_abilities: {}".format(filtered_student_abilities))
        logger.debug("Shape of filtered_question_difficulties: {}".format(filtered_question_difficulties))

        if self.args.use_ogive_model:
            # make the prediction using the normal ogive model
            dist = tfd.Normal(loc=0.0, scale=1.0)
            self.pred = dist.cdf(pred_z_values_1d)
            filtered_pred = dist.cdf(filtered_z_values)
        else:
            self.pred = tf.math.sigmoid(pred_z_values_1d)
            filtered_pred = tf.math.sigmoid(filtered_z_values)

        # convert the prediction probability to a logit, i.e. log(p / (1 - p))
        epsilon = 1e-6
        clipped_filtered_pred = tf.clip_by_value(filtered_pred, epsilon, 1. - epsilon)
        filtered_logits = tf.log(clipped_filtered_pred / (1 - clipped_filtered_pred))

        # cross-entropy loss
        cross_entropy = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=filtered_logits, labels=filtered_label))
        self.loss = cross_entropy

    def _create_optimizer(self):
        with tf.variable_scope('Optimizer'):
            self.optimizer = tf.train.AdamOptimizer(learning_rate=self.args.learning_rate)
            gvs = self.optimizer.compute_gradients(self.loss)
            clipped_gvs = [(tf.clip_by_norm(grad, self.args.max_grad_norm), var) for grad, var in gvs]
            self.train_op = self.optimizer.apply_gradients(clipped_gvs)

    def _add_summary(self):
        tf.summary.scalar('Loss', self.loss)
        self.tensorboard_writer = tf.summary.FileWriter(logdir=self.args.tensorboard_dir, graph=self.sess.graph)

        model_vars = tf.trainable_variables()
        total_size = 0
        total_bytes = 0
        model_msg = ""
        for var in model_vars:
            # if var.get_shape().num_elements() is None, assume size 0
            var_size = var.get_shape().num_elements() or 0
            var_bytes = var_size * var.dtype.size
            total_size += var_size
            total_bytes += var_bytes
            model_msg += ' '.join([
                var.name,
                tensor_description(var),
                '[%d, bytes: %d]' % (var_size, var_bytes)
            ])
            model_msg += '\n'
        model_msg += 'Total size of variables: %d \n' % total_size
        model_msg += 'Total bytes of variables: %d \n' % total_bytes
        logger.info(model_msg)
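# A short plain-NumPy sketch (illustrative numbers) of the Deep-IRT prediction and the
# probability-to-logit conversion in _create_loss above: with the sigmoid link,
# log(p / (1 - p)) recovers z = 3.0 * ability - difficulty (up to the clipping epsilon), so
# feeding it to sigmoid_cross_entropy_with_logits is consistent; with the ogive (normal CDF)
# link the recovered logit is only an approximation of z.
import numpy as np

ability, difficulty = 0.8, 0.5
z = 3.0 * ability - difficulty          # item-response-style logit
p = 1.0 / (1.0 + np.exp(-z))            # sigmoid link
eps = 1e-6
p_clipped = np.clip(p, eps, 1.0 - eps)
logit = np.log(p_clipped / (1.0 - p_clipped))
print(z, p, logit)                      # logit recovers z = 1.9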
def sym_gen(self):
    ### TODO input variable 'q_data'
    q_data = mx.sym.Variable('q_data', shape=(self.seqlen, self.batch_size))    # (seqlen, batch_size)
    ### TODO input variable 'qa_data'
    qa_data = mx.sym.Variable('qa_data', shape=(self.seqlen, self.batch_size))  # (seqlen, batch_size)
    ### TODO input variable 'target'
    target = mx.sym.Variable('target', shape=(self.seqlen, self.batch_size))    # (seqlen, batch_size)

    ### Initialize Memory
    init_memory_key = mx.sym.Variable('init_memory_key_weight')
    init_memory_value = mx.sym.Variable('init_memory_value',
                                        shape=(self.memory_size, self.memory_value_state_dim),
                                        init=mx.init.Normal(0.1))  # (memory_size, memory_value_state_dim)
    init_memory_value = mx.sym.broadcast_to(mx.sym.expand_dims(init_memory_value, axis=0),
                                            shape=(self.batch_size, self.memory_size, self.memory_value_state_dim))
    mem = DKVMN(memory_size=self.memory_size,
                memory_key_state_dim=self.memory_key_state_dim,
                memory_value_state_dim=self.memory_value_state_dim,
                init_memory_key=init_memory_key,
                init_memory_value=init_memory_value,
                name="DKVMN")

    ### embedding
    q_data = mx.sym.BlockGrad(q_data)
    q_embed_data = mx.sym.Embedding(data=q_data, input_dim=self.n_question + 1,
                                    output_dim=self.q_embed_dim, name='q_embed')
    slice_q_embed_data = mx.sym.SliceChannel(q_embed_data, num_outputs=self.seqlen, axis=0, squeeze_axis=True)
    qa_data = mx.sym.BlockGrad(qa_data)
    qa_embed_data = mx.sym.Embedding(data=qa_data, input_dim=self.n_question * 2 + 1,
                                     output_dim=self.qa_embed_dim, name='qa_embed')
    slice_qa_embed_data = mx.sym.SliceChannel(qa_embed_data, num_outputs=self.seqlen, axis=0, squeeze_axis=True)

    value_read_content_l = []
    input_embed_l = []
    for i in range(self.seqlen):
        ## Attention
        q = slice_q_embed_data[i]
        correlation_weight = mem.attention(q)
        ## Read Process
        read_content = mem.read(correlation_weight)  # Shape (batch_size, memory_state_dim)
        ### save intermediate data
        value_read_content_l.append(read_content)
        input_embed_l.append(q)
        ## Write Process
        qa = slice_qa_embed_data[i]
        new_memory_value = mem.write(correlation_weight, qa)

    all_read_value_content = mx.sym.Concat(*value_read_content_l, num_args=self.seqlen, dim=0)
    input_embed_content = mx.sym.Concat(*input_embed_l, num_args=self.seqlen, dim=0)
    input_embed_content = mx.sym.FullyConnected(data=input_embed_content, num_hidden=50,
                                                name="input_embed_content")
    input_embed_content = mx.sym.Activation(data=input_embed_content, act_type='tanh',
                                            name="input_embed_content_tanh")

    read_content_embed = mx.sym.FullyConnected(
        data=mx.sym.Concat(all_read_value_content, input_embed_content, num_args=2, dim=1),
        num_hidden=self.final_fc_dim, name="read_content_embed")
    read_content_embed = mx.sym.Activation(data=read_content_embed, act_type='tanh',
                                           name="read_content_embed_tanh")

    pred = mx.sym.FullyConnected(data=read_content_embed, num_hidden=1, name="final_fc")
    pred_prob = logistic_regression_mask_output(data=mx.sym.Reshape(pred, shape=(-1,)),
                                                label=mx.sym.Reshape(data=target, shape=(-1,)),
                                                ignore_label=-1., name='final_pred')
    return mx.sym.Group([pred_prob])
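# A small sketch (illustrative shapes only) of the SliceChannel/split used in sym_gen above: the
# embedded (seqlen, batch_size, dim) tensor is split along axis 0 into seqlen slices, and
# squeeze_axis=True drops the leading axis so each slice is (batch_size, dim), which is what
# mem.attention / mem.write consume one time step at a time.
import mxnet as mx

seqlen, batch_size, dim = 4, 2, 3
x = mx.nd.arange(seqlen * batch_size * dim).reshape((seqlen, batch_size, dim))
slices = mx.nd.split(x, num_outputs=seqlen, axis=0, squeeze_axis=True)
print(len(slices), slices[0].shape)  # 4 (2, 3)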