Beispiel #1
0
def eval_iter_callback(tensors, global_vars):
    """Accumulate masked punctuation/capitalization predictions and labels.

    Only entries of ``tensors`` whose key contains ``'~~~'`` are used: the
    text before the first ``'~~~'`` becomes the output name, and the value
    (a sequence of tensors) is joined with ``torch.cat``. The joined
    ``subtokens_mask`` is thresholded at 0.5 and used to select positions from
    the last-axis argmax of ``punct_logits`` / ``capit_logits`` and from
    ``punct_labels`` / ``capit_labels``. The selected values are appended to
    ``global_vars['punct_preds']``, ``['capit_preds']``, ``['punct_labels']``
    and ``['capit_labels']``.

    Args:
        tensors: mapping from tensor name to a sequence of tensors.
        global_vars: dict of accumulators; missing keys are created as ``[]``.

    Raises:
        KeyError: if one of the required outputs is absent from ``tensors``.

    NOTE(review): the ``*_all_*`` lists and ``all_subtokens_mask`` are created
    here but never filled by this function. Confirm whether another callback
    relies on them before removing.
    """
    accumulator_keys = (
        "punct_all_preds", "punct_all_labels",
        "capit_all_preds", "capit_all_labels",
        "all_subtokens_mask",
        'punct_labels', 'capit_labels', 'punct_preds', 'capit_preds',
    )
    for accumulator in accumulator_keys:
        global_vars.setdefault(accumulator, [])

    merged = {}
    for full_name, chunks in tensors.items():
        base_name, separator, _ = full_name.partition('~~~')
        if separator:
            merged[base_name] = torch.cat(chunks)

    keep = merged['subtokens_mask'] > 0.5

    def _select(values):
        # Keep only positions flagged by the subtoken mask, as plain Python values.
        return tensor2list(values[keep])

    global_vars['punct_preds'].extend(_select(torch.argmax(merged['punct_logits'], axis=-1)))
    global_vars['capit_preds'].extend(_select(torch.argmax(merged['capit_logits'], axis=-1)))
    global_vars['punct_labels'].extend(_select(merged['punct_labels']))
    global_vars['capit_labels'].extend(_select(merged['capit_labels']))
def eval_iter_callback(tensors, global_vars):
    """Accumulate token predictions, labels and subtoken masks for one step.

    Keys of ``tensors`` are matched by prefix:
    ``logits*``, ``labels*`` and ``subtokens_mask*`` (first match wins, in that
    order). Each value is iterated as batches, each batch sample by sample, and
    every sample is converted with ``tensor2list``.

    Per-sample logit lists are stacked. Predictions are the argmax over axis 2,
    flattened. Labels and masks are flattened by extending with each sample's
    list. Everything is appended to ``global_vars["all_preds"]``,
    ``["all_labels"]`` and ``["all_subtokens_mask"]``; missing keys are
    created as ``[]``.

    Raises:
        numpy AxisError (a ``ValueError``): if the stacked logits are not at
        least 3-D, e.g. when no ``logits*`` entry is present. Nothing is
        appended in that case.
    """
    for accumulator in ("all_preds", "all_labels", "all_subtokens_mask"):
        global_vars.setdefault(accumulator, [])

    def _per_sample(batches):
        # Lazily yield one converted list per sample across all batches.
        for batch in batches:
            for sample in batch:
                yield tensor2list(sample)

    logits_rows = []
    flat_labels = []
    flat_mask = []

    for name, batches in tensors.items():
        if name.startswith('logits'):
            logits_rows.extend(_per_sample(batches))
        elif name.startswith('labels'):
            for row in _per_sample(batches):
                flat_labels.extend(row)
        elif name.startswith('subtokens_mask'):
            for row in _per_sample(batches):
                flat_mask.extend(row)

    predictions = list(np.asarray(logits_rows).argmax(axis=2).flatten())
    global_vars["all_preds"].extend(predictions)
    global_vars["all_labels"].extend(flat_labels)
    global_vars["all_subtokens_mask"].extend(flat_mask)
def eval_iter_callback(tensors, global_vars):
    """Accumulate GLUE predictions and labels from one evaluation step.

    For every entry of ``tensors``:
    - If the key contains ``'logits'`` (classification tasks), each sample's
      ``tensor2list`` result is collected.
    - Otherwise, if the key contains ``'preds'`` (STS-B regression), each
      sample's ``tensor2list`` result is collected.
    - Independently, if the key contains ``'labels'``, each sample's
      ``tensor2list`` result is collected as a label.

    Predictions come from the argmax over axis 1 of the stacked logits when
    any were seen. Otherwise they are the squeezed regression outputs. They
    are appended to ``global_vars["all_preds"]``, and the labels to
    ``global_vars["all_labels"]``. Missing keys are created as ``[]``.

    Args:
        tensors: mapping from tensor name to a sequence of batched tensors.
        global_vars: dict of accumulators, updated in place.
    """
    if "all_preds" not in global_vars.keys():
        global_vars["all_preds"] = []
    if "all_labels" not in global_vars.keys():
        global_vars["all_labels"] = []

    logits_lists = []
    preds_lists = []
    labels_lists = []

    for kv, v in tensors.items():
        # for GLUE classification tasks
        if 'logits' in kv:
            for v_tensor in v:
                for logit_tensor in v_tensor:
                    logits_lists.append(tensor2list(logit_tensor))
        # for GLUE STS-B task (regression)
        elif 'preds' in kv:
            for v_tensor in v:
                for pred_tensor in v_tensor:
                    preds_lists.append(tensor2list(pred_tensor))
        if 'labels' in kv:
            for v_tensor in v:
                for label_tensor in v_tensor:
                    labels_lists.append(tensor2list(label_tensor))

    if logits_lists:
        preds = list(np.argmax(np.asarray(logits_lists), 1))
    elif preds_lists:
        # np.squeeze turns a single-sample batch into a 0-d array, which
        # list() cannot iterate. atleast_1d restores a 1-D shape without
        # changing the result for larger batches.
        preds = list(np.atleast_1d(np.squeeze(np.asarray(preds_lists))))
    else:
        # No logits/preds in this step (e.g. empty ``tensors``). Previously
        # ``preds`` was left unbound and the call raised UnboundLocalError.
        preds = []

    global_vars["all_preds"].extend(preds)
    global_vars["all_labels"].extend(labels_lists)
def eval_iter_callback(tensors, global_vars):
    """Accumulate punctuation and capitalization results for one eval step.

    Each entry of ``tensors`` is routed by its key; the first matching rule
    wins:

    1. contains ``'Punctuation'`` and ``'logits'``: punctuation logits
    2. starts with ``'punct_labels'``: punctuation labels
    3. contains ``'Capitalization'`` and ``'logits'``: capitalization logits
    4. starts with ``'capit_labels'``: capitalization labels
    5. starts with ``'subtokens_mask'``: subtoken mask

    Every value is iterated as batches, then samples, each converted with
    ``tensor2list``. Logits keep one list per sample; everything else is
    flattened. Predictions are the argmax over axis 2 of the stacked logits,
    flattened. Results go to ``global_vars["punct_all_preds"]``,
    ``["punct_all_labels"]``, ``["capit_all_preds"]``,
    ``["capit_all_labels"]`` and ``["all_subtokens_mask"]``; missing keys
    are created as ``[]``.

    Raises:
        numpy AxisError (a ``ValueError``): if a task's stacked logits are not
        at least 3-D, e.g. when no matching logits entry is present.
    """
    for accumulator in ("punct_all_preds", "punct_all_labels",
                        "capit_all_preds", "capit_all_labels",
                        "all_subtokens_mask"):
        global_vars.setdefault(accumulator, [])

    punct_logits, punct_labels = [], []
    capit_logits, capit_labels = [], []
    subtoken_mask = []

    def _per_sample(batches):
        # Lazily yield one converted list per sample across all batches.
        for batch in batches:
            for sample in batch:
                yield tensor2list(sample)

    # (matcher, destination, flatten) -- checked in order, first match wins.
    routing = (
        (lambda n: 'Punctuation' in n and 'logits' in n, punct_logits, False),
        (lambda n: n.startswith('punct_labels'), punct_labels, True),
        (lambda n: 'Capitalization' in n and 'logits' in n, capit_logits, False),
        (lambda n: n.startswith('capit_labels'), capit_labels, True),
        (lambda n: n.startswith('subtokens_mask'), subtoken_mask, True),
    )

    for name, batches in tensors.items():
        for matches, destination, flatten in routing:
            if not matches(name):
                continue
            for row in _per_sample(batches):
                if flatten:
                    destination.extend(row)
                else:
                    destination.append(row)
            break

    for logits, labels, preds_key, labels_key in (
        (punct_logits, punct_labels, "punct_all_preds", "punct_all_labels"),
        (capit_logits, capit_labels, "capit_all_preds", "capit_all_labels"),
    ):
        predictions = list(np.asarray(logits).argmax(axis=2).flatten())
        global_vars[preds_key].extend(predictions)
        global_vars[labels_key].extend(labels)

    global_vars["all_subtokens_mask"].extend(subtoken_mask)
def eval_iter_callback(tensors, global_vars, eval_data_layer):
    """Accumulate classification predictions and labels for one eval step.

    Any ``tensors`` key containing ``'logits'`` contributes per-sample logit
    lists. Any key containing ``'labels'`` contributes per-sample label lists.
    A key matching both contributes to both. Predictions are the argmax over
    axis 1 of the stacked logits. They are appended to
    ``global_vars["all_preds"]``, and the labels to
    ``global_vars["all_labels"]``. Missing keys are created as ``[]``.

    ``eval_data_layer`` is accepted for interface compatibility but is not
    used here.

    Raises:
        numpy AxisError (a ``ValueError``): if the stacked logits are not at
        least 2-D, e.g. when no ``logits`` entry is present.
    """
    for accumulator in ("all_preds", "all_labels"):
        if accumulator not in global_vars:
            global_vars[accumulator] = []

    per_sample_logits = []
    per_sample_labels = []

    for name, batches in tensors.items():
        for marker, destination in (('logits', per_sample_logits),
                                    ('labels', per_sample_labels)):
            if marker not in name:
                continue
            for batch in batches:
                destination.extend(tensor2list(sample) for sample in batch)

    predictions = list(np.asarray(per_sample_logits).argmax(axis=1))
    global_vars["all_preds"].extend(predictions)
    global_vars["all_labels"].extend(per_sample_labels)
def eval_iter_callback(tensors, global_vars, data_desc):
    """Accumulate dialogue-state evaluation results for one eval step.

    Entries of ``tensors`` are routed by key prefix:

    - ``gating_labels``: values are converted with ``tensor2list`` and
      appended directly to ``global_vars['gating_labels']``.
    - ``point_outputs``: argmax over the last axis, kept as numpy arrays.
    - ``gate_outputs``: argmax over the last axis, kept as numpy arrays.
    - ``tgt_ids``: kept as numpy arrays.

    For each point-output/target pair, a position counts as correct when the
    prediction equals the target or the target is ``data_desc.vocab.pad_id``.
    The per-item result is ``True`` only if all positions along the last axis
    are correct, and these booleans are appended to
    ``global_vars['comp_res']``. Gate argmaxes are concatenated along axis 0
    and appended to ``global_vars['gating_preds']``.

    Args:
        tensors: mapping from tensor name to a sequence of tensors.
        global_vars: dict of accumulators; missing keys are created as ``[]``.
            ``'loss'`` is created here but is not written by this function.
        data_desc: object exposing ``vocab.pad_id``; only read when
            point outputs are present.

    Raises:
        IndexError: if there are fewer ``tgt_ids`` entries than
            ``point_outputs`` entries.
    """
    for accumulator in ('loss', 'comp_res', 'gating_labels', 'gating_preds'):
        if accumulator not in global_vars:
            global_vars[accumulator] = []

    point_outputs_max_list = []
    tgt_ids_list = []
    gate_outputs_max_list = []
    for tensor_name, values_list in tensors.items():
        if tensor_name.startswith('gating_labels'):
            for values in values_list:
                global_vars['gating_labels'].extend(tensor2list(values))
        elif tensor_name.startswith('point_outputs'):
            for values in values_list:
                point_outputs_max_list.append(tensor2numpy(torch.argmax(values, dim=-1)))
        elif tensor_name.startswith('gate_outputs'):
            for values in values_list:
                gate_outputs_max_list.append(tensor2numpy(torch.argmax(values, dim=-1)))
        elif tensor_name.startswith('tgt_ids'):
            for values in values_list:
                tgt_ids_list.append(tensor2numpy(values))

    comp_res_list = []
    if point_outputs_max_list:
        # Look the pad id up once, and only when there is something to compare.
        pad_id = data_desc.vocab.pad_id
        for i, point_max in enumerate(point_outputs_max_list):
            tgt_ids = tgt_ids_list[i]
            # Padding positions never count against a prediction.
            matches = (point_max == tgt_ids) | (tgt_ids == pad_id)
            comp_res_list.extend(np.all(matches, axis=-1, keepdims=False).tolist())

    # np.concatenate raises on an empty list, so a step without gate
    # outputs contributes no gating predictions instead of crashing.
    if gate_outputs_max_list:
        gate_outputs_max = np.concatenate(gate_outputs_max_list, axis=0).tolist()
    else:
        gate_outputs_max = []

    global_vars['comp_res'].extend(comp_res_list)
    global_vars['gating_preds'].extend(gate_outputs_max)