コード例 #1
0
def test_f(args, y, output):
    """Compute prediction accuracy bucketed by how often a label has been seen.

    For every example in the batch the sequence is walked step by step while
    counting how many times each target label has appeared so far in that
    example. A step whose target label is appearing for the k-th time is
    tallied in bucket k, so bucket 1 holds first presentations, bucket 2
    second presentations, and so on.

    Args:
        args: Object exposing ``seq_length`` (number of steps iterated per
            example) and ``label_type`` (``'one_hot'`` or ``'five_hot'``).
        y: Batch of encoded target labels; ``np.shape(y)[0]`` is taken as the
            number of examples.
        output: Batch of encoded predictions, decoded the same way as ``y``.

    Returns:
        A tuple ``(accuracies, totals)`` of two 10-element lists covering
        buckets 1 through 10. ``accuracies[k - 1]`` is the fraction of
        correct predictions in bucket k (``0.`` when the bucket is empty) and
        ``totals[k - 1]`` is the number of steps that fell into bucket k.

    Raises:
        ValueError: If ``args.label_type`` is not a supported label type.
    """
    # NOTE(review): the decoders are assumed to return, per example, a
    # sequence of hashable labels indexable by time step -- confirm against
    # one_hot_decode / five_hot_decode.
    if args.label_type == 'one_hot':
        y_decode = one_hot_decode(y)
        output_decode = one_hot_decode(output)
    elif args.label_type == 'five_hot':
        y_decode = five_hot_decode(y)
        output_decode = five_hot_decode(output)
    else:
        raise ValueError(
            'unsupported label_type: {!r}'.format(args.label_type))

    # Occurrence counts are 1-based and reach seq_length when one class fills
    # a whole sequence, so seq_length + 1 slots are required. At least 11
    # slots are kept so buckets 1..10 can always be reported, even for
    # sequences shorter than 10 steps.
    n_buckets = max(args.seq_length + 1, 11)
    correct = [0] * n_buckets
    total = [0] * n_buckets

    for i in range(np.shape(y)[0]):
        y_i = y_decode[i]
        output_i = output_decode[i]
        class_count = {}  # label -> occurrences so far in this example
        for j in range(args.seq_length):
            label = y_i[j]
            occurrence = class_count.get(label, 0) + 1
            class_count[label] = occurrence
            total[occurrence] += 1
            if label == output_i[j]:
                correct[occurrence] += 1

    return [
        float(correct[k]) / total[k] if total[k] > 0 else 0.
        for k in range(1, 11)
    ], total[1:11]
コード例 #2
0
ファイル: one_shot_learning.py プロジェクト: shamanez/ntm
def test_f(args, y, output, *, verbose=True):
    """Compute prediction accuracy bucketed by how often a label has been seen.

    For every example in the batch the sequence is walked step by step while
    counting how many times each target label has appeared so far in that
    example. A step whose target label is appearing for the k-th time is
    tallied in bucket k, so bucket 1 holds first presentations, bucket 2
    second presentations, and so on.

    Args:
        args: Object exposing ``seq_length`` (number of steps iterated per
            example) and ``label_type`` (``'one_hot'`` or ``'five_hot'``).
        y: Batch of encoded target labels; ``np.shape(y)[0]`` is taken as the
            number of examples.
        output: Batch of encoded predictions, decoded the same way as ``y``.
        verbose: When true (the default, matching the historical behaviour)
            the running counters are printed at every step. Pass ``False``
            to silence the trace.

    Returns:
        A 10-element list whose entry ``k - 1`` is the fraction of correct
        predictions among steps in bucket k, for k from 1 to 10 (``0.`` when
        the bucket is empty).

    Raises:
        ValueError: If ``args.label_type` is not a supported label type.
    """
    # NOTE(review): the decoders are assumed to return, per example, a
    # sequence of hashable labels indexable by time step -- confirm against
    # one_hot_decode / five_hot_decode.
    if args.label_type == 'one_hot':
        y_decode = one_hot_decode(y)
        output_decode = one_hot_decode(output)
    elif args.label_type == 'five_hot':
        y_decode = five_hot_decode(y)
        output_decode = five_hot_decode(output)
    else:
        raise ValueError(
            'unsupported label_type: {!r}'.format(args.label_type))

    # Occurrence counts are 1-based and reach seq_length when one class fills
    # a whole sequence, so seq_length + 1 slots are required. At least 11
    # slots are kept so buckets 1..10 can always be reported, even for
    # sequences shorter than 10 steps.
    n_buckets = max(args.seq_length + 1, 11)
    correct = [0] * n_buckets  # correct predictions per occurrence bucket
    total = [0] * n_buckets  # steps observed per occurrence bucket

    for i in range(np.shape(y)[0]):  # one iteration per example in the batch
        y_i = y_decode[i]
        output_i = output_decode[i]
        if verbose:
            print("Printing the correct classes of the sequence", y_i)
        class_count = {}  # label -> occurrences so far in this example
        for j in range(args.seq_length):  # one iteration per time step
            if verbose:
                print(j)
            label = y_i[j]
            occurrence = class_count.get(label, 0) + 1
            class_count[label] = occurrence
            if verbose:
                print("Printing the class counts", class_count)
                print(
                    "printing the class cout of the current correct-class in the sequence",
                    occurrence)
            total[occurrence] += 1
            if verbose:
                print(total)

            if label == output_i[j]:
                correct[occurrence] += 1
            if verbose:
                print("Printing the correctness thing", correct)

    # Accuracy for bucket k is the share of k-th presentations, across the
    # whole batch, that were predicted correctly.
    return [
        float(correct[k]) / total[k] if total[k] > 0 else 0.
        for k in range(1, 11)
    ]