def main():
    '''
    Training and evaluation of the model.
    '''
    print('Training starts...')
    for epoch in range(num_of_epoch):
        print('\nEpoch', epoch + 1)
        # log the start time of the epoch
        start = time.time()
        # set the models in training mode
        clstm.train()
        policy_s.train()
        policy_n.train()
        policy_c.train()
        # reset the count of reread_or_skim_times
        reread_or_skim_times = 0
        policy_loss_sum = []
        encoder_loss_sum = []
        baseline_value_batch = []
        for index, train in enumerate(train_iterator):
            label = train.label.to(torch.long)  # cross entropy loss requires the long type
            text = train.text.view(CHUNCK_SIZE, BATCH_SIZE, CHUNCK_SIZE)  # transform 1*400 to 20*1*20
            curr_step = 0  # the position of the current chunk
            h_0 = torch.zeros([1, 1, 128]).to(device)  # run on GPU
            c_0 = torch.zeros([1, 1, 128]).to(device)
            count = 0  # maximum number of skim/reread actions: 5
            baseline_value_ep = []
            saved_log_probs = []  # for the use of the policy gradient update
            # collect the computational costs for every time step
            cost_ep = []
            while curr_step < CHUNCK_SIZE and count < 5:
                # Loop until the text can be classified, curr_step reaches 20, or count reaches the maximum of 5.
                # update count
                count += 1
                # pass the input through the cnn-lstm and policy s
                text_input = text[curr_step]  # text_input 1*20
                ht, ct = clstm(text_input, h_0, c_0)  # 1 * 128
                # detach the hidden state that is fed into the value net
                ht_ = ht.clone().detach().requires_grad_(True)
                # compute a baseline value with the value network
                bi = value_net(ht_)
                # 1 * 1 * 128, next input of the lstm
                h_0 = ht.unsqueeze(0)
                c_0 = ct
                # draw a stop decision
                stop_decision, log_prob_s = sample_policy_s(ht, policy_s)
                stop_decision = stop_decision.item()
                if stop_decision == 1:
                    # classify
                    break
                else:
                    reread_or_skim_times += 1
                    # draw an action (reread or skip)
                    step, log_prob_n = sample_policy_n(ht, policy_n)
                    curr_step += int(step)  # reread or skip
                    if curr_step < CHUNCK_SIZE and count < 5:
                        # If the code can still execute the next loop, it is not the last time step.
                        cost_ep.append(clstm_cost + s_cost + n_cost)
                        # add the baseline value
                        baseline_value_ep.append(bi)
                        # add the log prob for the current actions
                        saved_log_probs.append(log_prob_s + log_prob_n)
            # output of the classifier
            output_c = policy_c(ht)
            # cross entropy loss input shape: input (N, C), target (N)
            loss = criterion(output_c, label)  # positive value
            # draw a predicted label
            pred_label, log_prob_c = sample_policy_c(output_c)
            if stop_decision == 1:
                # add the cost of the last time step
                cost_ep.append(clstm_cost + s_cost + c_cost)
                saved_log_probs.append(log_prob_s + log_prob_c)
            else:
                # add the cost of the last time step
                cost_ep.append(clstm_cost + s_cost + c_cost + n_cost)
                # At this step the stop decision is forced, so its probability is 1
                # and its log probability is zero, which can be ignored in the sum.
                saved_log_probs.append(log_prob_c.unsqueeze(0))
            # add the baseline value
            baseline_value_ep.append(bi)
            # add the cross entropy loss
            encoder_loss_sum.append(loss)
            # compute the policy losses and value losses for the current episode
            policy_loss_ep, value_losses = compute_policy_value_losses(
                cost_ep, loss, saved_log_probs, baseline_value_ep, alpha, gamma)
            policy_loss_sum.append(torch.cat(policy_loss_ep).sum())
            baseline_value_batch.append(torch.cat(value_losses).sum())
            # update gradients
            if (index + 1) % batch_sz == 0:
                # take the average of the samples, backprop
                finish_episode(policy_loss_sum, encoder_loss_sum, baseline_value_batch)
                del policy_loss_sum[:], encoder_loss_sum[:], baseline_value_batch[:]
            if (index + 1) % 2000 == 0:
                print(f'\n current episode: {index + 1}')
                # log the current position of the text which the agent has gone through
                print('curr_step: ', curr_step)
                # log the sum of the rereading and skimming times
                print(f'current reread_or_skim_times: {reread_or_skim_times}')
        print('Epoch time elapsed: %.2f s' % (time.time() - start))
        print('reread_or_skim_times in this epoch:', reread_or_skim_times)
        count_all, count_correct = evaluate(clstm, policy_s, policy_n, policy_c, valid_iterator)
        print('Epoch: %s, Accuracy on the validation set: %.2f' % (epoch + 1, count_correct / count_all))
        count_all, count_correct = evaluate(clstm, policy_s, policy_n, policy_c, train_iterator)
        print('Epoch: %s, Accuracy on the training set: %.2f' % (epoch + 1, count_correct / count_all))
    print('Compute the accuracy on the testing set...')
    count_all, count_correct = evaluate(clstm, policy_s, policy_n, policy_c, test_iterator)
    print('Accuracy on the testing set: %.2f' % (count_correct / count_all))
def main():
    '''
    Training and evaluation of the model.
    '''
    print('Training starts...')
    for epoch in range(num_of_epoch):
        print('\nEpoch', epoch + 1)
        # log the start time of the epoch
        start = time.time()
        clstm.train()
        policy_c.train()
        policy_s.train()
        policy_loss_sum = []
        encoder_loss_sum = []
        baseline_value_batch = []
        for index, train in enumerate(train_iterator):
            label = train.label.to(torch.long)  # 64
            text = train.text.view(CHUNCK_SIZE, BATCH_SIZE, CHUNCK_SIZE)  # transform 1*400 to 20*1*20
            curr_step = 0
            # set up the initial input for the lstm
            h_0 = torch.zeros([1, 1, 128]).to(device)
            c_0 = torch.zeros([1, 1, 128]).to(device)
            saved_log_probs = []
            baseline_value_ep = []
            cost_ep = []  # collect the computational costs for every time step
            while curr_step < 20:
                # loop until the stop decision equals 1 or the whole text has been read
                # read a chunk
                text_input = text[curr_step]
                # hidden state
                ht, ct = clstm(text_input, h_0, c_0)  # 1 * 128
                h_0 = ht.unsqueeze(0).to(device)  # 1 * 1 * 128, next input of the lstm
                c_0 = ct
                # compute a baseline value with the value network
                ht_ = ht.clone().detach().requires_grad_(True).to(device)
                bi = value_net(ht_)
                # draw a stop decision
                stop_decision, log_prob_s = sample_policy_s(ht, policy_s)
                stop_decision = stop_decision.item()
                if stop_decision == 1:
                    break
                else:
                    curr_step += 1
                    if curr_step < 20:
                        # If the code can still execute the next loop, it is not the last time step.
                        cost_ep.append(clstm_cost + s_cost)
                        saved_log_probs.append(log_prob_s)
                        # add the baseline value
                        baseline_value_ep.append(bi)
            # add the baseline value at the last step
            baseline_value_ep.append(bi)
            cost_ep.append(clstm_cost + s_cost + c_cost)
            # output of the classifier
            output_c = policy_c(ht)
            # compute the cross entropy loss
            loss = criterion(output_c, label)
            encoder_loss_sum.append(loss)
            # draw a predicted label
            pred_label, log_prob_c = sample_policy_c(output_c)
            saved_log_probs.append(log_prob_c.unsqueeze(0) + log_prob_s)
            # compute the policy losses and value losses for the current episode
            policy_loss_ep, value_losses = compute_policy_value_losses(
                cost_ep, loss, saved_log_probs, baseline_value_ep, alpha, gamma)
            policy_loss_sum.append(torch.cat(policy_loss_ep).sum())
            baseline_value_batch.append(torch.cat(value_losses).sum())
            # backward and optimize
            if (index + 1) % batch_sz == 0:
                finish_episode(policy_loss_sum, encoder_loss_sum, baseline_value_batch)
                del policy_loss_sum[:], encoder_loss_sum[:], baseline_value_batch[:]
            # print log
            if (index + 1) % 2000 == 0:
                print(f'\n current episode: {index + 1}')
                # log the current position of the text which the agent has gone through
                print('curr_step: ', curr_step)
        print('Epoch time elapsed: %.2f s' % (time.time() - start))
        count_all, count_correct = evaluate_earlystop(clstm, policy_s, policy_c, valid_iterator)
        print('Epoch: %s, Accuracy on the validation set: %.2f' % (epoch + 1, count_correct / count_all))
        count_all, count_correct = evaluate_earlystop(clstm, policy_s, policy_c, train_iterator)
        print('Epoch: %s, Accuracy on the training set: %.2f' % (epoch + 1, count_correct / count_all))
    print('Compute the accuracy on the testing set...')
    count_all, count_correct = evaluate_earlystop(clstm, policy_s, policy_c, test_iterator)
    print('Accuracy on the testing set: %.2f' % (count_correct / count_all))
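# Note: sample_policy_s, sample_policy_n, sample_policy_c and compute_policy_value_losses are
# helpers defined elsewhere in this repo. Purely as a non-authoritative illustration of the
# sampling step used in both training loops above, a stop decision could be drawn from a
# two-way categorical distribution roughly as sketched below; the name
# _example_sample_stop_decision and the assumption that policy_s returns logits over
# {continue, stop} are hypothetical, not the repo's actual implementation.
def _example_sample_stop_decision(ht, policy_s):
    """Illustrative sketch only: sample a stop/continue action and its log probability."""
    from torch.distributions import Categorical
    probs = torch.softmax(policy_s(ht), dim=-1)  # probabilities over {0: continue, 1: stop}
    dist = Categorical(probs)
    action = dist.sample()                       # tensor holding 0 or 1
    return action, dist.log_prob(action)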