def run_test(test_data, net, end_token, device="cuda", rev_emb_dict=None, tokenizer=None, max_tokens=None): bleu_sum = 0.0 bleu_count = 0 for p1, p2 in test_data: p_list = [(p1, p2)] input_ids, attention_masks = tokenizer_encode(tokenizer, p_list, rev_emb_dict, device, max_tokens) output, output_hidden_states = net.bert_encode(input_ids, attention_masks) context, enc = output_hidden_states, (output.unsqueeze(0), output.unsqueeze(0)) input_seq = net.pack_input(p1, net.emb, device) # Return logits (N*outputvocab), res_tokens (1*N) # Always use the first token in input sequence, which is '#BEG' as the initial input of decoder. # The maximum length of the output is defined in class libbots.data. _, tokens = net.decode_chain_argmax(enc, input_seq.data[0:1], seq_len=data.MAX_TOKENS, context=context[0], stop_at_token=end_token) bleu_sum += utils.calc_bleu(tokens, p2[1:]) bleu_count += 1 return bleu_sum / bleu_count
def run_test(test_data, net, end_token, device="cpu"): bleu_sum = 0.0 bleu_count = 0 for p1, p2 in test_data: input_seq = model.pack_input(p1, net.emb, device) enc = net.encode(input_seq) _, tokens = net.decode_chain_argmax(enc, input_seq.data[0:1], seq_len=data.MAX_TOKENS, stop_at_token=end_token) bleu_sum += utils.calc_bleu(tokens, p2[1:]) bleu_count += 1 return bleu_sum / bleu_count
def run_test(test_data, net, end_token, device="cuda"): bleu_sum = 0.0 bleu_count = 0 for p1, p2 in test_data: input_seq = net.pack_input(p1, net.emb, device) # enc = net.encode(input_seq) context, enc = net.encode_context(input_seq) # Return logits (N*outputvocab), res_tokens (1*N) # Always use the first token in input sequence, which is '#BEG' as the initial input of decoder. # The maximum length of the output is defined in class libbots.data. _, tokens = net.decode_chain_argmax(enc, input_seq.data[0:1], seq_len=data.MAX_TOKENS, context=context[0], stop_at_token=end_token) bleu_sum += utils.calc_bleu(tokens, p2[1:]) bleu_count += 1 return bleu_sum / bleu_count
for idx, out_seq in enumerate(out_seq_list):
    ref_indices = out_idx[idx][1:]
    enc_item = net.get_encoded_item(enc, idx)
    if random.random() < TEACHER_PROB:
        # Train with teacher forcing: feed the reference tokens to the decoder.
        r = net.decode_teacher(enc_item, out_seq)
        bleu_temp = attention_model.seq_bleu(r, ref_indices)
        bleu_sum += bleu_temp
        # Recover the predicted tokens from the logits.
        seq = torch.max(r.data, dim=1)[1]
        seq = seq.cpu().numpy()
    else:
        # Train with argmax decoding: feed the decoder its own predictions.
        r, seq = net.decode_chain_argmax(enc_item, out_seq.data[0:1], len(ref_indices))
        bleu_temp = utils.calc_bleu(seq, ref_indices)
        bleu_sum += bleu_temp
    net_results.append(r)
    net_targets.extend(ref_indices)
    bleu_count += 1
    if not dial_shown:
        # data.decode_words maps token IDs back to word strings.
        ref_words = [utils.untokenize(data.decode_words(ref_indices, rev_emb_dict))]
        log.info("Reference: %s", " ~~|~~ ".join(ref_words))
        log.info("Predicted: %s, bleu=%.4f",
                 utils.untokenize(data.decode_words(seq, rev_emb_dict)), bleu_temp)
        dial_shown = True
results_v = torch.cat(net_results).to(device)
targets_v = torch.LongTensor(net_targets).to(device)
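seq_bleu scores the teacher-forced logits directly, without a separate decode pass. A plausible sketch, assuming it argmaxes the per-step logits and reuses calc_bleu; this mirrors the manual argmax in the teacher-forcing branch above, but is an assumption about the helper's actual body:

def seq_bleu(model_out, ref_seq):
    # model_out is an N x vocab logits tensor; take the argmax per step.
    model_seq = torch.max(model_out.data, dim=1)[1]
    return utils.calc_bleu(model_seq.cpu().numpy(), ref_seq)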
optimiser.zero_grad()
input_seq, out_seq_list, _, out_idx = model.pack_batch(batch, net.emb, device)
enc = net.encode(input_seq)
net_results = []
net_targets = []
for idx, out_seq in enumerate(out_seq_list):
    ref_indices = out_idx[idx][1:]
    enc_item = net.get_encoded_item(enc, idx)
    if random.random() < TEACHER_PROB:
        r = net.decode_teacher(enc_item, out_seq)
        bleu_sum += model.seq_bleu(r, ref_indices)
    else:
        r, seq = net.decode_chain_argmax(enc_item, out_seq.data[0:1], len(ref_indices))
        bleu_sum += utils.calc_bleu(seq, ref_indices)
    net_results.append(r)
    net_targets.extend(ref_indices)
    bleu_count += 1
results_v = torch.cat(net_results)
targets_v = torch.LongTensor(net_targets).to(device)
loss_v = F.cross_entropy(results_v, targets_v)
loss_v.backward()
optimiser.step()
losses.append(loss_v.item())
bleu = bleu_sum / bleu_count
bleu_test = run_test(test_data, net, end_token, device)
log.info("Epoch %d: mean loss %.3f, mean BLEU %.3f, test BLEU %.3f",
         epoch, np.mean(losses), bleu, bleu_test)
writer.add_scalar("loss", np.mean(losses), epoch)
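The writer in the last line is presumably a TensorBoard SummaryWriter created before the epoch loop. A sketch of that setup, together with the BLEU scalars one would likely log alongside the loss; the run-comment tag is hypothetical:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(comment="-seq2seq-xent")  # hypothetical run tag
# Alongside the loss scalar above, log train and test BLEU per epoch:
writer.add_scalar("bleu", bleu, epoch)
writer.add_scalar("bleu_test", bleu_test, epoch)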