def _inspect_random_sample(printCorrect):
    """Predict one randomly drawn entry of ``val`` and print how it fared.

    Relies on module-level state: ``val``, ``values_test``, ``model``,
    ``getData``, ``feat_x``, ``feat_y``, ``strings_x``, ``strings_y``,
    ``ctable``, ``colors``, ``SENTLEN``, ``utf2bw`` and ``pretty_join``.

    Args:
        printCorrect: when true, also print a ✅ line for every output that
            matches its reference; mismatches are always printed.

    Returns:
        True if the first model output and every feature in ``feat_y``
        matched the reference, False otherwise.
    """
    is_all_correct = True

    # Draw a random row of ``val`` and select everything that shares the
    # first level of its index.
    position = np.random.randint(0, len(val.index))
    ind = (val.index[position][0], slice(None))

    string_input = values_test.loc[(slice(None), ind[0]), :]
    preds = model.predict(
        getData(val.loc[ind], string_input, feat_x, strings_x))
    # Collapse the last axis of every output to the index of its top score.
    preds = [np.argmax(x, axis=-1) for x in preds]

    print("Predicted String Q", string_input[strings_x[0]][0], "from",
          "-".join(str(x) for x in string_input.index.values[0]))

    # preds[0] is checked against the strings_y[0] column and reported under
    # the label "Segmentation".
    if (np.argmax(string_input[strings_y[0]][0], axis=-1) == preds[0]).all():
        if printCorrect:
            print(colors.ok + '✅' + colors.close + "Segmentation")
    else:
        print(colors.fail + '❌' + colors.close + "Segmentation")
        is_all_correct = False
        print("    T", utf2bw(string_input[strings_y[0]][0]))
        print("     ", utf2bw(ctable.decode(preds[0][0], calc_argmax=False)))

    # The remaining outputs line up with feat_y: preds[i + 1] <-> feat_y[i].
    for i, v in enumerate(feat_y):
        correct = val[v].loc[ind]
        # Rebuild a one-hot table from the predicted class indices so it can
        # be compared cell by cell with the reference frame.
        res = np.zeros((SENTLEN, correct.shape[1]))
        for ii, c in enumerate(preds[i + 1][0]):
            res[ii, c] = 1
        pred = pd.DataFrame(res, columns=val[v].columns)

        if (correct.values == pred.values).all():
            if printCorrect:
                print(colors.ok + '✅' + colors.close + v)
        else:
            is_all_correct = False
            print(colors.fail + '❌' + colors.close + v, end=' ')
            print(' '.join(['T ' + pretty_join(correct), pretty_join(pred)]))

    return is_all_correct


def inspectOne(times=0, printCorrect=True):
    """Print predictions for random samples until a wrong one has been shown.

    Samples keep being drawn while the current one is entirely correct or
    fewer than ten have been inspected (counting from ``times``); a blank
    line is printed between samples. The search gives up once ``times``
    exceeds 10000.

    This used to recurse once per sample, so a long streak of correct
    predictions hit Python's recursion limit (about 1000 frames by default)
    long before the 10000 cap could apply. It now iterates instead, with the
    same stop conditions and the same output.

    Args:
        times: number of samples already inspected; callers normally leave
            this at 0.
        printCorrect: forwarded to the per-sample report; when true, matching
            outputs are printed as well as mismatching ones.

    Returns:
        None.
    """
    while times <= 10000:  # hard cap so the search always terminates
        is_all_correct = _inspect_random_sample(printCorrect)
        if not is_all_correct and times >= 10:
            return
        print("")
        times += 1
def calc_accuracy(name, model, mydata, data_length, debug=True):
    """Score the first ``data_length`` samples and log the mean accuracies.

    For each name ``v`` at position ``i`` of ``model.output_names`` the
    reference ``mydata[1][v][ite]`` (arg-maxed over its last axis) is
    compared with the module-level ``preds[i][ite]``; an output counts as
    correct for a sample only when every position matches. ``("agg", "agg")``
    is true when all outputs of the sample are correct.

    Note that predictions are read from the module-level ``preds``, not
    computed from ``model`` here, and that a row
    ``[name, <column means of the acc/agg columns>]`` is appended to the
    module-level ``accuracies`` frame as a side effect.

    Args:
        name: label stored in the first cell of the new ``accuracies`` row.
        model: object exposing ``output_names``; only that attribute is used.
        mydata: indexable whose item ``1`` maps an output name to its
            per-sample reference arrays.
        data_length: number of leading samples to score.
        debug: when true, also store readable ``pred`` / ``actu`` text for
            each output; otherwise those cells are empty strings.

    Returns:
        A DataFrame with one row per scored sample and two-level columns
        ``(output, "acc" | "actu" | "pred")`` plus ``("agg", "agg")``.
    """
    output_names = list(model.output_names)

    # Build every (output, kind) pair explicitly. Zipping the repeated name
    # list against the repeated kind list only covers all pairs when the
    # number of outputs is not a multiple of three; otherwise it yields
    # duplicated columns and silently loses the rest.
    columns = pd.MultiIndex.from_tuples(
        [(v, kind) for v in output_names for kind in ("acc", "pred", "actu")]
        + [("agg", "agg")])
    results = pd.DataFrame([], columns=columns)
    results.sort_index(axis=1, inplace=True)

    for ite in range(data_length):
        r = {}
        for i, v in enumerate(output_names):
            actual = mydata[1][v][ite]
            pred = preds[i][ite]
            correct = np.argmax(actual, axis=-1)
            r[(v, "acc")] = (correct == pred).all()
            if not debug:
                r[(v, "pred")] = ""
                r[(v, "actu")] = ""
            elif v in strings_y:
                # String outputs are decoded through the character table.
                r[(v, "pred")] = utf2bw(
                    ctable.decode(pred, calc_argmax=False)).strip(" ")
                r[(v, "actu")] = utf2bw(
                    ctable.decode(actual, calc_argmax=True)).strip(" ")
            else:
                r[(v, "pred")] = pretty_join2(pred, val[v].columns, v)
                r[(v, "actu")] = pretty_join2(correct, val[v].columns, v)
        r[("agg", "agg")] = all(r[(v, "acc")] for v in output_names)
        results.loc[ite] = r

    # Pick the boolean columns by their second level instead of a regex over
    # the whole label, so an output whose name happens to contain "acc" or
    # "agg" cannot drag its text columns into the numeric average.
    is_score = results.columns.get_level_values(1).isin(["acc", "agg"])
    scores = results.loc[:, is_score].values.astype(int)
    accuracies.loc[len(accuracies)] = [name] + list(np.average(scores, axis=0))
    return results
Exemple #3
0
                ctable.decode(mydata[0][strings_x[0]][ite],
                              calc_argmax=True), "from",
                "-".join(str(x) for x in values_test.index.values[ite]))

            for ii, v in enumerate(model.output_names):
                if v not in strings_y:
                    continue

                if not (np.argmax(mydata[1][v][ite], axis=-1)
                        == preds[ii][ite]).all():
                    print('❌' + v)
                    isAllCorrect = False
                    print(
                        "    T",
                        utf2bw(
                            ctable.decode(mydata[1][v][ite],
                                          calc_argmax=True)))
                    print(
                        "     ",
                        utf2bw(ctable.decode(preds[ii][ite],
                                             calc_argmax=False)))

            rowy = dict()
            for ii, v in enumerate(model.output_names):
                if v in strings_y:
                    continue
                rowy[v] = {"correct": np.argmax(mydata[1][v][ite], axis=-1)}
                rowy[v]["pred"] = preds[ii][ite]
                results = []
                #         print(val[v].columns)
                if not (rowy[v]["correct"] == rowy[v]["pred"]).all():
# In[207]:

# for ite in range(len(mydata[0][strings_x[0]])):
# Walk the first 10 samples and print every output whose prediction differs
# from the reference.
# NOTE(review): isAllCorrect is only ever set to False in the visible code; it
# is neither initialised nor reset per sample here — confirm that is intended.
for ite in range(10):
    # Header line: the decoded input string followed by the sample's index
    # values joined with "-".
    print("\nPredicted String Q", ctable.decode(mydata[0][strings_x[0]][ite], calc_argmax=True), "from",
          "-".join(str(x) for x in values_test.index.values[ite]))
    
    # First pass: only the outputs listed in strings_y.
    for i, v in enumerate(model.output_names):
        if v not in strings_y:
            continue

        # The reference is arg-maxed over its last axis before the comparison;
        # preds[i][ite] is compared as-is. Any differing position is a miss.
        if not (np.argmax(mydata[1][v][ite], axis=-1) == preds[i][ite]).all():
            print(colors.fail + '❌' + colors.close + v)
            isAllCorrect = False
            # "T" line: decoded reference; the unlabelled line below it: the
            # decoded prediction.
            print("    T", utf2bw(ctable.decode(mydata[1][v][ite], calc_argmax=True)))
            print("     ", utf2bw(ctable.decode(preds[i][ite], calc_argmax=False)))

    # Second pass: every output that is NOT in strings_y.
    rowy = dict()
    for i, v in enumerate(model.output_names):
        if v in strings_y:
            continue
        # Keep the arg-maxed reference and the prediction side by side.
        rowy[v] = {"correct": np.argmax(mydata[1][v][ite], axis=-1)}
        rowy[v]["pred"] = preds[i][ite]
        # Collects the text fragments for this output's mismatch line; it is
        # filled below but not printed within the visible lines.
        results = []
#         print(val[v].columns)
        if not (rowy[v]["correct"] == rowy[v]["pred"]).all():
            isAllCorrect = False
            # end=' ' keeps the cursor on the same line as the ❌ marker.
            print(colors.fail + '❌' + colors.close + v, end=' ')
            #             results.append(colors.fail + '☒' + colors.close)
            results.append('T ' + pretty_join2(rowy[v]["correct"],val[v].columns,v))