Example #1
def run_test(test_data, net, rev_emb_dict, end_token, beg_token, device="cuda"):
    argmax_reward_sum = 0.0
    argmax_reward_count = 0.0
    # p1 is a single input sentence; p2 is a list of reference sentences.
    for p1, p2 in test_data:
        # Transform sentence to padded embeddings.
        input_seq = net.pack_input(p1, net.emb, device)
        # Get hidden states from encoder.
        # enc = net.encode(input_seq)
        context, enc = net.encode_context(input_seq)
        # Decode greedily, feeding each predicted token back into the net.
        # Returns logits of shape (N, output_vocab) and the N output token indices.
        _, tokens = net.decode_chain_argmax(enc,
                                            net.emb(beg_token),
                                            seq_len=data.MAX_TOKENS,
                                            context=context[0],
                                            stop_at_token=end_token)
        # Show what the output action sequence is.
        action_tokens = []
        for temp_idx in tokens:
            if temp_idx in rev_emb_dict and rev_emb_dict.get(
                    temp_idx) != '#END':
                action_tokens.append(str(rev_emb_dict.get(temp_idx)).upper())
        # Using 0-1 reward to compute accuracy.
        reward = utils.calc_True_Reward_webqsp_novar(action_tokens, p2, False)
        # reward = random.random()
        argmax_reward_sum += float(reward)
        argmax_reward_count += 1
    if argmax_reward_count == 0:
        return 0.0
    else:
        return float(argmax_reward_sum) / float(argmax_reward_count)
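A minimal calling sketch for run_test above. It assumes the surrounding repository provides emb_dict (token to index), rev_emb_dict, a trained net and a list of test pairs; test_pairs and the other surrounding names are illustrative, not part of the snippet:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
end_token = emb_dict['#END']  # index of the end token used by the decoder
beg_token = torch.LongTensor([emb_dict[data.BEGIN_TOKEN]]).to(device)
accuracy = run_test(test_pairs, net, rev_emb_dict, end_token, beg_token, device=device)
print("0-1 reward accuracy: %.3f" % accuracy)
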
def run_test_true_reward(test_data, net, rev_emb_dict, end_token, device="cuda"):
    argmax_reward_sum = 0.0
    argmax_reward_count = 0.0

    # BEGIN token index; emb_dict (token -> index) is assumed to be defined at
    # module level. .to(device) already places the tensor on the chosen device.
    beg_token = torch.LongTensor([emb_dict[data.BEGIN_TOKEN]]).to(device)

    # p1 is a single input sentence; p2 is a list of reference sentences.
    for p1, p2 in test_data:
        p_list = [(p1, p2)]
        input_ids, attention_masks = tokenizer_encode(tokenizer, p_list, rev_emb_dict, device, max_tokens)
        output, output_hidden_states = net.bert_encode(input_ids, attention_masks)
        # Use the per-token hidden states as the attention context and pack the
        # sequence-level output as the decoder's initial state pair.
        context, enc = output_hidden_states, (output.unsqueeze(0), output.unsqueeze(0))
        input_seq = net.pack_input(p1, net.emb, device)
        # Returns logits of shape (N, output_vocab) and the N decoded token indices.
        # The first token of the input sequence ('#BEG') is always used as the
        # decoder's initial input.
        # The maximum output length (MAX_TOKENS) is defined in the libbots.data module.
        _, tokens = net.decode_chain_argmax(enc, input_seq.data[0:1],
                                            seq_len=data.MAX_TOKENS,
                                            context=context[0],
                                            stop_at_token=end_token)
        action_tokens = []
        for temp_idx in tokens:
            if temp_idx in rev_emb_dict and rev_emb_dict.get(temp_idx) != '#END':
                action_tokens.append(str(rev_emb_dict.get(temp_idx)).upper())
        # Using 0-1 reward to compute accuracy.
        if args.dataset == "csqa":
            argmax_reward_sum += float(utils.calc_True_Reward(action_tokens, p2, False))
        else:
            argmax_reward_sum += float(utils.calc_True_Reward_webqsp_novar(action_tokens, p2, False))

        argmax_reward_count += 1

    if argmax_reward_count == 0:
        return 0.0
    else:
        return float(argmax_reward_sum) / float(argmax_reward_count)
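The utils.calc_True_Reward and calc_True_Reward_webqsp_novar helpers are not shown on this page; conceptually they execute the decoded action sequence and compare the result with the gold answers. A minimal 0-1 reward sketch, with a hypothetical execute_program standing in for the real interpreter:

def zero_one_reward(action_tokens, qa_info, execute_program):
    # Hypothetical sketch: execute_program stands in for the interpreter behind
    # utils.calc_True_Reward / calc_True_Reward_webqsp_novar.
    predicted = set(execute_program(action_tokens, qa_info))
    # "answer" is an assumed key for the gold answer set in qa_info.
    gold = set(qa_info.get("answer", []))
    # Reward 1.0 only on an exact match of the answer set, 0.0 otherwise.
    return 1.0 if predicted == gold else 0.0
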
Example #3
                    # Greedy (argmax) rollout, used as the self-critic baseline.
                    _, actions = net.decode_chain_argmax(
                        item_enc,
                        beg_embedding,
                        data.MAX_TOKENS,
                        context[idx],
                        stop_at_token=end_token)
                    # Show what the output action sequence is.
                    action_tokens = []
                    for temp_idx in actions:
                        if temp_idx in rev_emb_dict and rev_emb_dict.get(
                                temp_idx) != '#END':
                            action_tokens.append(
                                str(rev_emb_dict.get(temp_idx)).upper())
                    # The greedy (argmax) reward serves as the self-critic baseline.
                    # If args.adaptive is False, the 0-1 reward is used to compute
                    # accuracy; otherwise the adaptive reward is used.
                    argmax_reward = utils.calc_True_Reward_webqsp_novar(
                        action_tokens, qa_info, args.adaptive)
                    # argmax_reward = random.random()
                    true_reward_argmax.append(argmax_reward)

                    if args.NSM and 'pseudo_gold_program_reward' not in qa_info:
                        pseudo_program_tokens = str(
                            qa_info['pseudo_gold_program']).strip().split()
                        pseudo_program_reward = utils.calc_True_Reward_webqsp_novar(
                            pseudo_program_tokens, qa_info, args.adaptive)
                        qa_info[
                            'pseudo_gold_program_reward'] = pseudo_program_reward

                    # # In this case, the BLEU score is so high that it is not needed to train such case with RL.
                    # if not args.disable_skip and argmax_reward > 0.99:
                    #     skipped_samples += 1
                    #     continue
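The greedy reward computed above typically acts as the self-critic baseline for sampled rollouts. A short sketch of that policy-gradient step; sample_logprob and sample_reward are hypothetical names for the sampled decode, not variables from the snippet:

def self_critic_loss(sample_logprob, sample_reward, argmax_reward):
    # sample_logprob: summed log-probability of a sampled action sequence (a tensor).
    # sample_reward: true reward of that sample; argmax_reward: greedy baseline.
    # Rollouts that do not beat the greedy decode receive no positive reinforcement.
    advantage = sample_reward - argmax_reward
    return -advantage * sample_logprob
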
Example #4
        bleu = utils.calc_bleu_many(tokens, references)
        # print(tokens)
        # print(references)

        # Show what the output action sequence is.
        action_tokens = []
        for temp_idx in tokens:
            if temp_idx in rev_emb_dict and rev_emb_dict.get(
                    temp_idx) != '#END':
                action_tokens.append(str(rev_emb_dict.get(temp_idx)).upper())
        # Compute the true reward for the decoded action sequence. The last
        # argument is False, so the 0-1 reward is used to compute accuracy
        # instead of the adaptive reward.

        true_reward_F1score = utils.calc_True_Reward_webqsp_novar(
            action_tokens, targets, False)
        print("true_reward_F1score", true_reward_F1score)
        sum_turereward_f1 += true_reward_F1score

        intersec = set(tokens).intersection(set(references[0]))

        # Precision is measured against the prediction, recall against the reference.
        if len(tokens) == 0:
            prec = 0.0
        else:
            prec = float(len(intersec)) / float(len(tokens))
        if len(references[0]) == 0:
            rec = 0.0
        else:
            rec = float(len(intersec)) / float(len(references[0]))
        if prec == 0 and rec == 0:
            f1 = 0.0
        else:
            sum_f1 += (2.0 * prec * rec) / (prec + rec)
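The precision/recall bookkeeping above can be factored into a small helper. A sketch assuming simple token-overlap F1 against a single reference; the function name is illustrative:

def token_f1(predicted_tokens, reference_tokens):
    # Token-overlap F1: precision over the prediction, recall over the reference.
    overlap = set(predicted_tokens) & set(reference_tokens)
    if not predicted_tokens or not reference_tokens or not overlap:
        return 0.0
    prec = len(overlap) / len(predicted_tokens)
    rec = len(overlap) / len(reference_tokens)
    return 2.0 * prec * rec / (prec + rec)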