Code example #1
def main():
    '''
    Training and evaluation of the model.
    '''
    print('Training starts...')
    for epoch in range(num_of_epoch):
        print('\nEpoch', epoch + 1)
        # log the start time of the epoch
        start = time.time()
        # set the models in training mode
        clstm.train()
        policy_s.train()
        policy_n.train()
        policy_c.train()
        # reset the count of reread_or_skim_times
        reread_or_skim_times = 0
        policy_loss_sum = []
        encoder_loss_sum = []
        baseline_value_batch = []
        for index, train in enumerate(train_iterator):
            label = train.label.to(
                torch.long
            )  # for cross entropy loss, the long type is required
            text = train.text.view(CHUNCK_SIZE, BATCH_SIZE,
                                   CHUNCK_SIZE)  # transform 1*400 to 20*1*20
            curr_step = 0  # the position of the current chunk
            h_0 = torch.zeros([1, 1, 128]).to(device)  # initial hidden state, on device
            c_0 = torch.zeros([1, 1, 128]).to(device)  # initial cell state
            count = 0  # maximum number of skim/reread steps is 5
            baseline_value_ep = []
            saved_log_probs = []  # log probs saved for the policy gradient update
            # collect the computational costs for every time step
            cost_ep = []
            while curr_step < CHUNCK_SIZE and count < 5:
                # Loop until the text is classified, curr_step reaches
                # CHUNCK_SIZE (20), or count reaches the maximum of 5.
                # update count
                count += 1
                # pass the input through cnn-lstm and policy s
                text_input = text[curr_step]  # text_input 1*20
                ht, ct = clstm(text_input, h_0, c_0)  # 1 * 128
                # make a detached copy of the hidden state as input to the value net
                ht_ = ht.clone().detach().requires_grad_(True)
                # compute a baseline value for the value network
                bi = value_net(ht_)
                # 1 * 1 * 128, next input of lstm
                h_0 = ht.unsqueeze(0)
                c_0 = ct
                # draw a stop decision
                stop_decision, log_prob_s = sample_policy_s(ht, policy_s)
                stop_decision = stop_decision.item()
                if stop_decision == 1:  # classify
                    break
                else:
                    reread_or_skim_times += 1
                    # draw an action (reread or skip)
                    step, log_prob_n = sample_policy_n(ht, policy_n)
                    curr_step += int(step)  # reread or skip
                    if curr_step < CHUNCK_SIZE and count < 5:
                        # If the loop will run again, this is not the last time step.
                        cost_ep.append(clstm_cost + s_cost + n_cost)
                        # add the baseline value
                        baseline_value_ep.append(bi)
                        # add the log prob for the current actions
                        saved_log_probs.append(log_prob_s + log_prob_n)
            # compute the classifier output
            output_c = policy_c(ht)
            # cross entropy loss input shape: input(N, C), target(N)
            loss = criterion(output_c, label)  # positive value
            # draw a predicted label
            pred_label, log_prob_c = sample_policy_c(output_c)
            if stop_decision == 1:
                # add the cost of the last time step
                cost_ep.append(clstm_cost + s_cost + c_cost)
                saved_log_probs.append(log_prob_s + log_prob_c)
            else:
                # add the cost of the last time step
                cost_ep.append(clstm_cost + s_cost + c_cost + n_cost)
                # Here the stop decision is forced, so its probability is 1 and
                # its log probability is zero, which can be ignored in the sum.
                saved_log_probs.append(log_prob_c.unsqueeze(0))
            # add the baseline value
            baseline_value_ep.append(bi)
            # add the cross entropy loss
            encoder_loss_sum.append(loss)
            # compute the policy losses and value losses for the current episode
            policy_loss_ep, value_losses = compute_policy_value_losses(
                cost_ep, loss, saved_log_probs, baseline_value_ep, alpha,
                gamma)
            policy_loss_sum.append(torch.cat(policy_loss_ep).sum())
            baseline_value_batch.append(torch.cat(value_losses).sum())
            # update gradients
            if (index + 1) % batch_sz == 0:  # average over the batch, then backprop
                finish_episode(policy_loss_sum, encoder_loss_sum,
                               baseline_value_batch)
                del policy_loss_sum[:], encoder_loss_sum[:], baseline_value_batch[:]

            if (index + 1) % 2000 == 0:
                print(f'\n current episode: {index + 1}')
                # log how far through the text the agent has read
                print('curr_step: ', curr_step)
                # log the sum of the rereading and skimming times
                print(f'current reread_or_skim_times: {reread_or_skim_times}')

        print('Epoch time elapsed: %.2f s' % (time.time() - start))
        print('reread_or_skim_times in this epoch:', reread_or_skim_times)
        count_all, count_correct = evaluate(clstm, policy_s, policy_n,
                                            policy_c, valid_iterator)
        print('Epoch: %s, Accuracy on the validation set: %.2f' %
              (epoch + 1, count_correct / count_all))
        count_all, count_correct = evaluate(clstm, policy_s, policy_n,
                                            policy_c, train_iterator)
        print('Epoch: %s, Accuracy on the training set: %.2f' %
              (epoch + 1, count_correct / count_all))

    print('Compute the accuracy on the testing set...')
    count_all, count_correct = evaluate(clstm, policy_s, policy_n, policy_c,
                                        test_iterator)
    print('Accuracy on the testing set: %.2f' % (count_correct / count_all))
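Both examples call helper functions that are defined elsewhere in the repository (`sample_policy_s`, `sample_policy_n`, `sample_policy_c`, `compute_policy_value_losses`). The sketch below is a minimal, hypothetical reconstruction of them: only the call signatures and return shapes are taken from the loops themselves, while the softmax heads, the step-size mapping, and the reward shaping (negative classification loss at the final step minus the alpha-weighted computational cost, discounted by gamma) are assumptions, not the authors' actual design.

import torch
from torch.distributions import Categorical


def sample_policy_s(ht, policy_s):
    # policy_s is assumed to map the 1*128 hidden state to two logits
    # (0 = continue reading, 1 = stop and classify)
    dist = Categorical(logits=policy_s(ht))
    stop_decision = dist.sample()  # shape (1,)
    return stop_decision, dist.log_prob(stop_decision)


def sample_policy_n(ht, policy_n):
    # policy_n is assumed to output logits over a small set of step sizes
    # (0 = reread the current chunk, k = skip k chunks ahead); this
    # index-to-offset mapping is hypothetical
    dist = Categorical(logits=policy_n(ht))
    step = dist.sample()
    return step, dist.log_prob(step)


def sample_policy_c(output_c):
    # sample a predicted label from the classifier logits; the squeeze gives
    # a 0-d log prob, matching the .unsqueeze(0) calls in the loops above
    dist = Categorical(logits=output_c)
    pred_label = dist.sample()
    return pred_label, dist.log_prob(pred_label).squeeze()


def compute_policy_value_losses(cost_ep, loss, saved_log_probs,
                                baseline_value_ep, alpha, gamma):
    # assumed reward: -alpha * cost at every step, plus the negative
    # classification loss at the final step, discounted by gamma
    R = 0.0
    returns = []
    for t, cost in reversed(list(enumerate(cost_ep))):
        r = -alpha * cost
        if t == len(cost_ep) - 1:
            r -= loss.item()
        R = r + gamma * R
        returns.insert(0, R)
    policy_losses, value_losses = [], []
    for log_prob, b, R in zip(saved_log_probs, baseline_value_ep, returns):
        advantage = R - b.item()  # the baseline only reduces variance
        policy_losses.append(-log_prob * advantage)  # REINFORCE term
        value_losses.append((b - R).pow(2).view(1))  # value net regression
    return policy_losses, value_losses

Note that the loops detach `ht` before feeding `value_net`, so the squared-error term above would train only the value network, while the advantage-weighted log probs train the policies.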
Code example #2
def main():
    '''
    Training and evaluation of the model.
    '''
    print('Training starts...')
    for epoch in range(num_of_epoch):
        print('\nEpoch', epoch+1)
        # log the start time of the epoch
        start = time.time()
        clstm.train()
        policy_c.train()
        policy_s.train()
        policy_loss_sum = []
        encoder_loss_sum = []
        baseline_value_batch = []
        for index, train in enumerate(train_iterator):
            label = train.label.to(torch.long)  # cast to long for the cross entropy loss
            text = train.text.view(CHUNCK_SIZE, BATCH_SIZE, CHUNCK_SIZE) # transform 1*400 to 20*1*20
            curr_step = 0
            # set up the initial input for lstm
            h_0 = torch.zeros([1,1,128]).to(device) 
            c_0 = torch.zeros([1,1,128]).to(device)
            saved_log_probs = []
            baseline_value_ep = []
            cost_ep = []   # collect the computational costs for every time step
            while curr_step < CHUNCK_SIZE:
                # loop until the stop decision equals 1
                # or the whole text has been read
                # read a chunk
                text_input = text[curr_step]
                # hidden state
                ht, ct = clstm(text_input, h_0, c_0)  # 1 * 128
                h_0 = ht.unsqueeze(0).to(device)  # 1 * 1 * 128, next input of lstm
                c_0 = ct
                # compute a baseline value for the value network
                ht_ = ht.clone().detach().requires_grad_(True).to(device)
                bi = value_net(ht_)
                # draw a stop decision
                stop_decision, log_prob_s = sample_policy_s(ht, policy_s)
                stop_decision = stop_decision.item()
                if stop_decision == 1:
                    break
                else:
                    curr_step += 1
                    if curr_step < CHUNCK_SIZE:
                        # If the loop will run again, this is not the last time step.
                        cost_ep.append(clstm_cost + s_cost)
                        # add the log prob for the stop decision
                        saved_log_probs.append(log_prob_s)
                        # add the baseline value
                        baseline_value_ep.append(bi)
                    
            # add the baseline value at the last step
            baseline_value_ep.append(bi)
            cost_ep.append(clstm_cost + s_cost + c_cost)
            # output of the classifier
            output_c = policy_c(ht)
            # compute cross entropy loss
            loss = criterion(output_c, label)
            encoder_loss_sum.append(loss)
            # draw a predicted label 
            pred_label, log_prob_c = sample_policy_c(output_c)
            # log prob of the final stop decision plus the classification action
            saved_log_probs.append(log_prob_c.unsqueeze(0) + log_prob_s)
            # compute the policy losses and value losses for the current episode
            policy_loss_ep, value_losses = compute_policy_value_losses(cost_ep, loss, saved_log_probs, baseline_value_ep, alpha, gamma)
            policy_loss_sum.append(torch.cat(policy_loss_ep).sum())
            baseline_value_batch.append(torch.cat(value_losses).sum())
            # Backward and optimize
            if (index + 1) % batch_sz == 0:
                finish_episode(policy_loss_sum, encoder_loss_sum, baseline_value_batch)
                del policy_loss_sum[:], encoder_loss_sum[:], baseline_value_batch[:]  
            
            # print log
            if (index + 1) % 2000 == 0:
                print(f'\n current episode: {index + 1}')
                # log how far through the text the agent has read
                print('curr_step: ', curr_step)
        print('Epoch time elapsed: %.2f s' % (time.time() - start))
        count_all, count_correct = evaluate_earlystop(clstm, policy_s, policy_c, valid_iterator)
        print('Epoch: %s, Accuracy on the validation set: %.2f' % (epoch + 1, count_correct / count_all))
        count_all, count_correct = evaluate_earlystop(clstm, policy_s, policy_c, train_iterator)
        print('Epoch: %s, Accuracy on the training set: %.2f' % (epoch + 1, count_correct / count_all))
    
    print('Compute the accuracy on the testing set...')
    count_all, count_correct = evaluate_earlystop(clstm, policy_s, policy_c, test_iterator)
    print('Accuracy on the testing set: %.2f' % (count_correct / count_all))
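Both training loops delegate the gradient step to `finish_episode`, which is also not shown. Below is a minimal sketch, assuming a single optimizer (named `optimizer` here, hypothetically) covering the encoder, the policy heads, and the value network; only the call signature comes from the loops above.

import torch


def finish_episode(policy_loss_sum, encoder_loss_sum, baseline_value_batch):
    # average the accumulated per-episode losses over the batch and take
    # one gradient step; `optimizer` is a hypothetical module-level name
    batch = len(policy_loss_sum)
    total = (torch.stack(policy_loss_sum).sum()
             + torch.stack(encoder_loss_sum).sum()
             + torch.stack(baseline_value_batch).sum()) / batch
    optimizer.zero_grad()
    total.backward()
    optimizer.step()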