Code Example #1
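Imports assumed by this snippet; the remaining helper names are taken from the surrounding SOM-DST codebase.

import json
import sys
import time
from copy import deepcopy

import numpy as np
import torch

# OP_SET, domain2id, make_turn_label, postprocessing, compute_acc, compute_prf,
# per_domain_join_accuracy, wrap_into_tensor, and wandb are assumed to be
# provided by the surrounding SOM-DST codebase and its dependencies; their
# exact module paths are not shown here.
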
def model_evaluation(model,
                     test_data,
                     tokenizer,
                     slot_meta,
                     epoch,
                     op_code='4',
                     is_gt_op=False,
                     is_gt_p_state=False,
                     is_gt_gen=False,
                     use_full_slot=False,
                     use_dt_only=False,
                     no_dial=False,
                     use_cls_only=False,
                     n_gpu=0,
                     submission=False,
                     use_wandb=False):

    device = torch.device('cuda' if n_gpu else 'cpu')

    model.eval()
    op2id = OP_SET[op_code]
    id2op = {v: k for k, v in op2id.items()}
    id2domain = {v: k for k, v in domain2id.items()}  # built for symmetry with id2op; not used below

    slot_turn_acc, joint_acc, slot_F1_pred, slot_F1_count = 0, 0, 0, 0
    final_joint_acc, final_count, final_slot_F1_pred, final_slot_F1_count = 0, 0, 0, 0
    op_acc = 0
    op_F1 = {k: 0 for k in op2id}
    op_F1_count = {k: 0 for k in op2id}
    all_op_F1_count = {k: 0 for k in op2id}

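    # Per-operation TP/FN/FP tallies, consumed after the loop to compute
    # per-operation precision, recall, and F1.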
    tp_dic = {k: 0 for k in op2id}
    fn_dic = {k: 0 for k in op2id}
    fp_dic = {k: 0 for k in op2id}

    results = {}
    last_dialog_state = {}
    wall_times = []
    if submission:
        _submission = {}

    start_time = time.time()
    for di, i in enumerate(test_data):
        if (di + 1) % 1000 == 0:
            print("{:}, {:.1f}min".format(di + 1, (time.time() - start_time) / 60))
            sys.stdout.flush()

        if i.turn_id == 0:
            last_dialog_state = {}

        if is_gt_p_state is False:
            i.last_dialog_state = deepcopy(last_dialog_state)
            i.make_instance(tokenizer, word_dropout=0.)
        else:  # ground-truth previous dialogue state
            last_dialog_state = deepcopy(i.gold_p_state)
            i.last_dialog_state = deepcopy(last_dialog_state)
            i.make_instance(tokenizer, word_dropout=0.)

        # Map each slot index to the token ids of its "domain slot -" prefix.
        # (id2ds is built here but never consumed later in this function.)
        id2ds = {}
        for idx, s in enumerate(i.slot_meta):
            k = s.split('-')
            # print(k)  # e.g. ['attraction', 'area']
            id2ds[idx] = tokenizer.convert_tokens_to_ids(
                tokenizer.tokenize(' '.join(k + ['-'])))

        tensor_list = wrap_into_tensor(
            [i],
            pad_id=tokenizer.convert_tokens_to_ids(['[PAD]'])[0],
            slot_id=tokenizer.convert_tokens_to_ids(['[SLOT]'])[0])[:4]
        tensor_list = [t.to(device) for t in tensor_list]
        input_ids_p, segment_ids_p, input_mask_p, state_position_ids = tensor_list

        d_gold_op, _, _ = make_turn_label(slot_meta,
                                          last_dialog_state,
                                          i.gold_state,
                                          tokenizer,
                                          op_code,
                                          dynamic=True)
        gold_op_ids = torch.LongTensor([d_gold_op]).to(device)

        start = time.perf_counter()

        MAX_LENGTH = 9
        if n_gpu > 1:
            model.module.decoder.min_len = 1  # ask the decoder to generate at least one token ([SEP] counts toward this minimum)
        else:
            model.decoder.min_len = 1

        with torch.no_grad():
            # ground-truth state operation
            gold_op_inputs = gold_op_ids if is_gt_op else None

            if n_gpu > 1:
                d, s, generated = model.module.output(
                    input_ids_p,
                    segment_ids_p,
                    input_mask_p,
                    state_position_ids,
                    i.diag_len,
                    op_ids=gold_op_inputs,
                    gen_max_len=MAX_LENGTH,
                    use_full_slot=use_full_slot,
                    use_dt_only=use_dt_only,
                    diag_1_len=i.diag_1_len,
                    no_dial=no_dial,
                    use_cls_only=use_cls_only,
                    i_dslen_map=i.i_dslen_map)
            else:
                d, s, generated = model.output(input_ids_p,
                                               segment_ids_p,
                                               input_mask_p,
                                               state_position_ids,
                                               i.diag_len,
                                               op_ids=gold_op_inputs,
                                               gen_max_len=MAX_LENGTH,
                                               use_full_slot=use_full_slot,
                                               use_dt_only=use_dt_only,
                                               diag_1_len=i.diag_1_len,
                                               no_dial=no_dial,
                                               use_cls_only=use_cls_only,
                                               i_dslen_map=i.i_dslen_map)

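        # s holds per-slot operation logits; the argmax over the operation
        # dimension yields the predicted operation id for every slot.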
        _, op_ids = s.view(-1, len(op2id)).max(-1)

        if is_gt_op:
            pred_ops = [id2op[a] for a in gold_op_ids[0].tolist()]
        else:
            pred_ops = [id2op[a] for a in op_ids.tolist()]
        gold_ops = [id2op[a] for a in d_gold_op]

        if is_gt_gen:
            # ground-truth generation
            gold_gen = {
                '-'.join(ii.split('-')[:2]): ii.split('-')[-1]
                for ii in i.gold_state
            }
        else:
            gold_gen = {}

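        # Apply the predicted operations (and any generated values) to the
        # previous dialogue state to obtain the state after this turn.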
        generated, last_dialog_state = postprocessing(slot_meta, pred_ops,
                                                      last_dialog_state,
                                                      generated, tokenizer,
                                                      op_code, gold_gen)

        # print(last_dialog_state)

        end = time.perf_counter()
        wall_times.append(end - start)
        pred_state = []
        for k, v in last_dialog_state.items():
            pred_state.append('-'.join([k, v]))

        if set(pred_state) == set(i.gold_state):
            joint_acc += 1
        key = str(i.id) + '_' + str(i.turn_id)
        results[key] = [pred_state, i.gold_state]
        if submission:
            key_sub = str(i.id) + '-' + str(i.turn_id)
            _submission[key_sub] = pred_state

        # Compute prediction slot accuracy
        temp_acc = compute_acc(set(i.gold_state), set(pred_state), slot_meta)
        slot_turn_acc += temp_acc

        # Compute prediction F1 score
        temp_f1, temp_r, temp_p, count = compute_prf(i.gold_state, pred_state)
        slot_F1_pred += temp_f1
        slot_F1_count += count

        # Compute operation accuracy
        temp_acc = sum(p == g for p, g in zip(pred_ops, gold_ops)) / len(pred_ops)
        op_acc += temp_acc

        if i.is_last_turn:
            final_count += 1
            if set(pred_state) == set(i.gold_state):
                final_joint_acc += 1

            final_slot_F1_pred += temp_f1
            final_slot_F1_count += count

        # Compute operation F1 score
        for p, g in zip(pred_ops, gold_ops):
            all_op_F1_count[g] += 1
            if p == g:
                tp_dic[g] += 1
                op_F1_count[g] += 1
            else:
                fn_dic[g] += 1
                fp_dic[p] += 1

    joint_acc_score = joint_acc / len(test_data)
    turn_acc_score = slot_turn_acc / len(test_data)
    slot_F1_score = slot_F1_pred / slot_F1_count
    op_acc_score = op_acc / len(test_data)
    final_joint_acc_score = final_joint_acc / final_count
    final_slot_F1_score = final_slot_F1_pred / final_slot_F1_count
    latency = np.mean(wall_times) * 1000
    op_F1_score = {}
    for k in op2id.keys():
        tp = tp_dic[k]
        fn = fn_dic[k]
        fp = fp_dic[k]
        precision = tp / (tp + fp) if (tp + fp) != 0 else 0
        recall = tp / (tp + fn) if (tp + fn) != 0 else 0
        F1 = 2 * precision * recall / float(precision + recall) if (
            precision + recall) != 0 else 0
        op_F1_score[k] = F1

    print("------------------------------")
    print('op_code: %s, is_gt_op: %s, is_gt_p_state: %s, is_gt_gen: %s' % \
          (op_code, str(is_gt_op), str(is_gt_p_state), str(is_gt_gen)))
    print("Epoch %d joint accuracy : " % epoch, joint_acc_score)
    print("Epoch %d slot turn accuracy : " % epoch, turn_acc_score)
    print("Epoch %d slot turn F1: " % epoch, slot_F1_score)
    print("Epoch %d op accuracy : " % epoch, op_acc_score)
    print("Epoch %d op F1 : " % epoch, op_F1_score)
    print("Epoch %d op hit count : " % epoch, op_F1_count)
    print("Epoch %d op all count : " % epoch, all_op_F1_count)
    print("Final Joint Accuracy : ", final_joint_acc_score)
    print("Final slot turn F1 : ", final_slot_F1_score)
    print("Latency Per Prediction : %f ms" % latency)
    print("-----------------------------\n")

    if submission:
        # The dump is JSON-formatted even though the file name ends in .csv.
        with open(f"{epoch}-output.csv", "w") as f:
            json.dump(_submission, f, indent=2, ensure_ascii=False)
        scores = {}
    else:
        with open('preds_%d.json' % epoch, 'w') as f:
            json.dump(results, f)

        if use_wandb:
            wandb.log({
                "joint_goal_accuracy": joint_acc_score,
                "turn_slot_accuracy": turn_acc_score,
                "turn_slot_f1": slot_F1_score
            })

        per_domain_join_accuracy(results, slot_meta)

        scores = {
            'epoch': epoch,
            'joint_acc': joint_acc_score,
            'slot_acc': turn_acc_score,
            'slot_f1': slot_F1_score,
            'op_acc': op_acc_score,
            'op_f1': op_F1_score,
            'final_slot_f1': final_slot_F1_score
        }
    return scores
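
A minimal usage sketch for the variant above: every name here is a placeholder for an object assumed to be built elsewhere in the SOM-DST pipeline (a trained model, a tokenizer, prepared test instances), not part of the original snippet.

# Hypothetical call; `model`, `test_data`, `tokenizer`, and `slot_meta`
# are assumed to have been constructed elsewhere.
scores = model_evaluation(model, test_data, tokenizer, slot_meta,
                          epoch=10, op_code='4', n_gpu=1)
print(scores['joint_acc'], scores['slot_f1'])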
Code Example #2
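Imports assumed by this snippet; note that this variant reads a module-level `device` rather than deriving it from an n_gpu argument.

import json
import time
from copy import deepcopy

import numpy as np
import torch

# device, OP_SET, domain2id, make_turn_label, postprocessing, compute_acc,
# compute_prf, and per_domain_join_accuracy are assumed to be defined at
# module level elsewhere in the SOM-DST codebase.
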
def model_evaluation(model,
                     test_data,
                     tokenizer,
                     slot_meta,
                     epoch,
                     op_code='4',
                     is_gt_op=False,
                     is_gt_p_state=False,
                     is_gt_gen=False):
    model.eval()
    op2id = OP_SET[op_code]
    id2op = {v: k for k, v in op2id.items()}
    id2domain = {v: k for k, v in domain2id.items()}

    slot_turn_acc, joint_acc, slot_F1_pred, slot_F1_count = 0, 0, 0, 0
    final_joint_acc, final_count, final_slot_F1_pred, final_slot_F1_count = 0, 0, 0, 0
    op_acc = 0
    op_F1 = {k: 0 for k in op2id}
    op_F1_count = {k: 0 for k in op2id}
    all_op_F1_count = {k: 0 for k in op2id}

    tp_dic = {k: 0 for k in op2id}
    fn_dic = {k: 0 for k in op2id}
    fp_dic = {k: 0 for k in op2id}

    results = {}
    last_dialog_state = {}
    wall_times = []
    for di, i in enumerate(test_data):
        if i.turn_id == 0:
            last_dialog_state = {}

        if is_gt_p_state is False:
            i.last_dialog_state = deepcopy(last_dialog_state)
            i.make_instance(tokenizer, word_dropout=0.)
        else:  # ground-truth previous dialogue state
            last_dialog_state = deepcopy(i.gold_p_state)
            i.last_dialog_state = deepcopy(last_dialog_state)
            i.make_instance(tokenizer, word_dropout=0.)

        input_ids = torch.LongTensor([i.input_id]).to(device)
        input_mask = torch.LongTensor([i.input_mask]).to(device)
        segment_ids = torch.LongTensor([i.segment_id]).to(device)
        state_position_ids = torch.LongTensor([i.slot_position]).to(device)

        d_gold_op, _, _ = make_turn_label(slot_meta,
                                          last_dialog_state,
                                          i.gold_state,
                                          tokenizer,
                                          op_code,
                                          dynamic=True)
        gold_op_ids = torch.LongTensor([d_gold_op]).to(device)

        start = time.perf_counter()
        MAX_LENGTH = 9
        with torch.no_grad():
            # ground-truth state operation
            gold_op_inputs = gold_op_ids if is_gt_op else None
            d, s, g = model(input_ids=input_ids,
                            token_type_ids=segment_ids,
                            state_positions=state_position_ids,
                            attention_mask=input_mask,
                            max_value=MAX_LENGTH,
                            op_ids=gold_op_inputs)

        _, op_ids = s.view(-1, len(op2id)).max(-1)

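        # g holds generation logits; greedy-decode them into token ids
        # whenever the model generated anything for this turn.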
        if g.size(1) > 0:
            generated = g.squeeze(0).max(-1)[1].tolist()
        else:
            generated = []

        if is_gt_op:
            pred_ops = [id2op[a] for a in gold_op_ids[0].tolist()]
        else:
            pred_ops = [id2op[a] for a in op_ids.tolist()]
        gold_ops = [id2op[a] for a in d_gold_op]

        if is_gt_gen:
            # ground-truth generation
            gold_gen = {
                '-'.join(ii.split('-')[:2]): ii.split('-')[-1]
                for ii in i.gold_state
            }
        else:
            gold_gen = {}
        generated, last_dialog_state = postprocessing(slot_meta, pred_ops,
                                                      last_dialog_state,
                                                      generated, tokenizer,
                                                      op_code, gold_gen)
        end = time.perf_counter()
        wall_times.append(end - start)
        pred_state = []
        for k, v in last_dialog_state.items():
            pred_state.append('-'.join([k, v]))

        if set(pred_state) == set(i.gold_state):
            joint_acc += 1
        key = str(i.id) + '_' + str(i.turn_id)
        results[key] = [pred_state, i.gold_state]

        # Compute prediction slot accuracy
        temp_acc = compute_acc(set(i.gold_state), set(pred_state), slot_meta)
        slot_turn_acc += temp_acc

        # Compute prediction F1 score
        temp_f1, temp_r, temp_p, count = compute_prf(i.gold_state, pred_state)
        slot_F1_pred += temp_f1
        slot_F1_count += count

        # Compute operation accuracy
        temp_acc = sum(p == g for p, g in zip(pred_ops, gold_ops)) / len(pred_ops)
        op_acc += temp_acc

        if i.is_last_turn:
            final_count += 1
            if set(pred_state) == set(i.gold_state):
                final_joint_acc += 1
            final_slot_F1_pred += temp_f1
            final_slot_F1_count += count

        # Compute operation F1 score
        for p, g in zip(pred_ops, gold_ops):
            all_op_F1_count[g] += 1
            if p == g:
                tp_dic[g] += 1
                op_F1_count[g] += 1
            else:
                fn_dic[g] += 1
                fp_dic[p] += 1

    joint_acc_score = joint_acc / len(test_data)
    turn_acc_score = slot_turn_acc / len(test_data)
    slot_F1_score = slot_F1_pred / slot_F1_count
    op_acc_score = op_acc / len(test_data)
    final_joint_acc_score = final_joint_acc / final_count
    final_slot_F1_score = final_slot_F1_pred / final_slot_F1_count
    latency = np.mean(wall_times) * 1000
    op_F1_score = {}
    for k in op2id.keys():
        tp = tp_dic[k]
        fn = fn_dic[k]
        fp = fp_dic[k]
        precision = tp / (tp + fp) if (tp + fp) != 0 else 0
        recall = tp / (tp + fn) if (tp + fn) != 0 else 0
        F1 = 2 * precision * recall / float(precision + recall) if (
            precision + recall) != 0 else 0
        op_F1_score[k] = F1

    print("------------------------------")
    print('op_code: %s, is_gt_op: %s, is_gt_p_state: %s, is_gt_gen: %s' % \
          (op_code, str(is_gt_op), str(is_gt_p_state), str(is_gt_gen)))
    print("Epoch %d joint accuracy : " % epoch, joint_acc_score)
    print("Epoch %d slot turn accuracy : " % epoch, turn_acc_score)
    print("Epoch %d slot turn F1: " % epoch, slot_F1_score)
    print("Epoch %d op accuracy : " % epoch, op_acc_score)
    print("Epoch %d op F1 : " % epoch, op_F1_score)
    print("Epoch %d op hit count : " % epoch, op_F1_count)
    print("Epoch %d op all count : " % epoch, all_op_F1_count)
    print("Final Joint Accuracy : ", final_joint_acc_score)
    print("Final slot turn F1 : ", final_slot_F1_score)
    print("Latency Per Prediction : %f ms" % latency)
    print("-----------------------------\n")
    with open('preds_%d.json' % epoch, 'w') as f:
        json.dump(results, f)
    per_domain_join_accuracy(results, slot_meta)

    scores = {
        'epoch': epoch,
        'joint_acc': joint_acc_score,
        'slot_acc': turn_acc_score,
        'slot_f1': slot_F1_score,
        'op_acc': op_acc_score,
        'op_f1': op_F1_score,
        'final_slot_f1': final_slot_F1_score
    }
    return scores
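
For reference, op_code selects which set of state operations the model can predict. The sketch below shows the OP_SET layout as it appears in the SOM-DST data utilities, reproduced from memory; treat the exact contents as an assumption rather than a verified excerpt.

# Assumed layout of OP_SET; op_code '4' gives the four-way
# delete/update/dontcare/carryover scheme used by default above.
OP_SET = {
    '2': {'update': 0, 'carryover': 1},
    '3-1': {'update': 0, 'carryover': 1, 'dontcare': 2},
    '3-2': {'update': 0, 'carryover': 1, 'delete': 2},
    '4': {'delete': 0, 'update': 1, 'dontcare': 2, 'carryover': 3},
    '6': {'delete': 0, 'update': 1, 'dontcare': 2, 'carryover': 3,
          'yes': 4, 'no': 5},
}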
Code Example #3
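Imports assumed by this snippet; `args` (an argparse namespace) and `device` are assumed to be module-level globals from the surrounding codebase.

import json
import os
import time
from copy import deepcopy

import torch

# args, device, OP_SET, domain2id, and postprocessing are assumed to be
# defined elsewhere in the SOM-DST codebase.
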
def inference_model(model, test_data, tokenizer, slot_meta, op_code='4'):
    model.eval()
    op2id = OP_SET[op_code]
    id2op = {v: k for k, v in op2id.items()}
    id2domain = {v: k for k, v in domain2id.items()}

    results = {}
    last_dialog_state = {}
    wall_times = []
    for di, i in enumerate(test_data):
        if (di + 1) % 1000 == 0:
            print(f"processed {di + 1} test instances")

        if i.turn_id == 0:
            last_dialog_state = {}

        i.last_dialog_state = deepcopy(last_dialog_state)
        # print(di, last_dialog_state)
        i.make_instance(tokenizer, word_dropout=0.)

        input_ids = torch.LongTensor([i.input_id]).to(device)
        segment_ids = torch.LongTensor([i.segment_id]).to(device)
        state_position_ids = torch.LongTensor([i.slot_position]).to(device)
        input_mask = torch.LongTensor([i.input_mask]).to(device)
        # print(f"input_id : {input_ids}")
        # print(f"segment_id : {segment_ids}")
        # print(f"slot_position : {state_position_ids}")
        # print(f"input_mask : {input_mask}")

        start = time.perf_counter()
        MAX_LENGTH = 9
        with torch.no_grad():
            _, s, g = model(input_ids=input_ids,
                            token_type_ids=segment_ids,
                            state_positions=state_position_ids,
                            attention_mask=input_mask,
                            max_value=MAX_LENGTH,
                            )
        # print(s.shape, s)
        # print(g.shape, g)
        _, op_ids = s.view(-1, len(op2id)).max(-1)
        # print(f"op_ids : {op_ids}")

        if g.size(1) > 0:
            generated = g.squeeze(0).max(-1)[1].tolist()
        else:
            generated = []
        # print(f"g.shape : {g.shape}, before_generated : {generated}")

        pred_ops = [id2op[a] for a in op_ids.tolist()]
        # print(pred_ops)
        
        generated, last_dialog_state = postprocessing(slot_meta, pred_ops, last_dialog_state,
                                                      generated, tokenizer, op_code)


        end = time.perf_counter()
        wall_times.append(end - start)
        pred_state = []
        for k, v in last_dialog_state.items():
            # Truncate anything the decoder emitted from the first [UNK] onward.
            if isinstance(v, str):
                v = v.split('[UNK]')[0].strip()
            pred_state.append('-'.join([k, v]))

        key = str(i.id) + '-' + str(i.turn_id)
        results[key] = pred_state
        # print(f"{di}, {key} : {pred_state}, generated : {generated}, {len(generated)}")

    # `args` is assumed to be a module-level argparse namespace carrying the
    # checkpoint path; the output directory is specific to this training environment.
    output_path = '/opt/ml/code/p3-dst-chatting-day/SomDST/predictions/'
    os.makedirs(output_path, exist_ok=True)
    output_file = args.model_ckpt_path.split('/')[-1].split('.')[0] + '_outputs.csv'
    with open(output_path + output_file, 'w', encoding='UTF-8') as f:
        json.dump(results, f, ensure_ascii=False)
    print(f"{output_path + output_file} is saved!")