예제 #1
0
파일: main.py 프로젝트: manthanb/thesis
def test(index=108, show_plots=True):
    """Run the trained HNM over one test sequence and count correct steps.

    Builds the three sub-networks (Obstacle, WallCW, WallACW) with their saved
    parameters from ``program_memory/``, wraps them in ``HNM(10, 14, ...)``,
    loads the learned weights from ``learned_params/hnm.pt`` and then steps
    the model through test sequence ``index`` one timestep at a time. A step
    is counted as correct when the argmax of the model output equals the
    argmax of the label for that step.

    Args:
        index: Which sequence from ``getTestData()`` to evaluate. Defaults to
            108, the sequence the original script hard-coded.
        show_plots: When True (default, matching the original behaviour),
            display the HNM ``Memory`` matrix once before the first step and
            again after every step. Each ``plt.show()`` call may block.

    Returns:
        The number of correctly predicted steps (also printed).
    """
    obstacle, wall_cw, wall_acw = Obstacle(), WallCW(), WallACW()
    obstacle_params = torch.load('program_memory/move.pt')
    wall_cw_params = torch.load('program_memory/cw.pt')
    wall_acw_params = torch.load('program_memory/acw.pt')

    networks = [obstacle, wall_cw, wall_acw]
    params = [obstacle_params, wall_cw_params, wall_acw_params]

    hnm = HNM(10, 14, networks, params)
    hnm.load_state_dict(torch.load('learned_params/hnm.pt'))

    testX, testY = getTestData()
    # unsqueeze(0) adds a leading batch axis; the loop indexes s[:, step, :],
    # so s is presumably (1, timesteps, features) — TODO confirm.
    s = torch.from_numpy(np.array(testX[index])).float().unsqueeze(0)
    labels = np.array(testY[index])

    print(s.size())

    # Evaluation only: without no_grad each step would extend an autograd
    # graph through the read/write weights carried from the previous step.
    with torch.no_grad():
        read_weights, write_weights = hnm._initialise()

        if show_plots:
            plt.matshow(hnm.Memory.detach().numpy())
            plt.show()

        correct = 0
        for step in range(s.size()[1]):
            out, read_weights, write_weights = hnm.forward(
                s[:, step, :], read_weights, write_weights)
            if np.argmax(out.detach().numpy()) == np.argmax(labels[step]):
                correct += 1
            if show_plots:
                plt.matshow(hnm.Memory.detach().numpy())
                plt.show()

    print(correct)
    return correct
예제 #2
0
파일: main.py 프로젝트: manthanb/thesis
def compare(num_samples=25):
    """Compare per-step action accuracy of the trained HNM, NTM and LSTM.

    Loads the HNM (built around the three pretrained sub-networks from
    ``program_memory/``), the NTM and the LSTM from ``learned_params/``, then
    feeds the first ``num_samples`` test sequences through each model one
    timestep at a time. A step is counted as correct for a model when the
    argmax of its output equals the argmax of the label for that step.

    Args:
        num_samples: Number of leading sequences from ``getTestData()`` to
            evaluate. Defaults to 25, the count the original script used.

    Returns:
        ``(hnm_correct, ntm_correct, lstm_correct, total_steps)`` — the same
        numbers that are printed.
    """
    # Hidden size shared by the LSTM constructor and its initial-state expand.
    lstm_hidden = 64

    obstacle, wall_cw, wall_acw = Obstacle(), WallCW(), WallACW()
    obstacle_params = torch.load('program_memory/move.pt')
    wall_cw_params = torch.load('program_memory/cw.pt')
    wall_acw_params = torch.load('program_memory/acw.pt')
    networks = [obstacle, wall_cw, wall_acw]
    params = [obstacle_params, wall_cw_params, wall_acw_params]
    hnm = HNM(10, 14, networks, params)
    hnm.load_state_dict(torch.load('learned_params/hnm.pt'))

    ntm = NTM(10, 14)
    ntm.load_state_dict(torch.load('learned_params/ntm.pt'))

    lstm = LSTM(14, lstm_hidden, 3, 1)
    lstm.load_state_dict(torch.load('learned_params/lstm.pt'))

    testX, testY = getTestData()

    hnm_correct, ntm_correct, lstm_correct = 0, 0, 0
    total_steps = 0

    # Evaluation only: without no_grad every recurrent step would extend an
    # autograd graph through the carried read/write weights and LSTM state.
    with torch.no_grad():
        for i in range(num_samples):
            # unsqueeze(0) adds a leading batch axis; s is presumably
            # (1, timesteps, features) — TODO confirm.
            s = torch.from_numpy(np.array(testX[i])).float().unsqueeze(0)
            # NOTE(review): view() reinterprets the contiguous buffer as
            # (batch, dim2, -1); it does NOT transpose, so s_lstm[:, :, j] is
            # not the features of timestep j unless one of the two trailing
            # dims is 1. Kept as-is because the LSTM may have been trained
            # with this same view; otherwise this likely wants
            # s.transpose(1, 2). TODO confirm against the LSTM training code.
            s_lstm = s.view(s.size()[0], s.size()[2], -1)
            labels = np.array(testY[i])

            print(i)

            hnm_read_weights, hnm_write_weights = hnm._initialise()
            ntm_read_weights, ntm_write_weights = ntm._initialise()
            lstm_h = lstm.h0.expand(s_lstm.size()[0], lstm_hidden)
            lstm_c = lstm.c0.expand(s_lstm.size()[0], lstm_hidden)

            for j in range(s.size()[1]):
                hnm_out, hnm_read_weights, hnm_write_weights = hnm.forward(
                    s[:, j, :], hnm_read_weights, hnm_write_weights)
                ntm_out, ntm_read_weights, ntm_write_weights = ntm.forward(
                    s[:, j, :], ntm_read_weights, ntm_write_weights)
                lstm_h, lstm_c, lstm_out = lstm.forward(
                    s_lstm[:, :, j], lstm_h, lstm_c)

                # Label argmax is shared by all three comparisons.
                target = np.argmax(labels[j])
                if np.argmax(hnm_out.detach().numpy()) == target:
                    hnm_correct += 1
                if np.argmax(ntm_out.detach().numpy()) == target:
                    ntm_correct += 1
                if np.argmax(lstm_out.detach().numpy()) == target:
                    lstm_correct += 1

                total_steps += 1

    print(hnm_correct, ntm_correct, lstm_correct)
    print(total_steps)
    return hnm_correct, ntm_correct, lstm_correct, total_steps
    print(totSamples)
예제 #3
0
파일: main.py 프로젝트: manthanb/thesis
def trainHNM():
    """Build a fresh HNM around the three pretrained sub-networks and train it.

    The sub-network modules are created in a fixed order (Obstacle, WallCW,
    WallACW). Their saved parameters are loaded in the matching order from
    ``program_memory/move.pt``, ``cw.pt`` and ``acw.pt``. Both lists are
    passed to ``HNM(10, 14, ...)``, and the model is then trained on the
    data returned by ``getData()`` via ``hnm.train(X, y, 1)``.
    """
    sub_modules = [Obstacle(), WallCW(), WallACW()]
    checkpoint_paths = (
        'program_memory/move.pt',
        'program_memory/cw.pt',
        'program_memory/acw.pt',
    )
    sub_params = [torch.load(path) for path in checkpoint_paths]

    model = HNM(10, 14, sub_modules, sub_params)

    inputs, targets = getData()

    model.train(inputs, targets, 1)
예제 #4
0
파일: main.py 프로젝트: manthanb/thesis
def trainHNM(num_chunks=10, chunk_stride=500, data_dir="data"):
    """Build a fresh HNM, gather chunked training data and train the model.

    The three sub-networks (Obstacle, WallCW, WallACW) and their saved
    parameters from ``program_memory/`` are wrapped in ``HNM(10, 14, ...)``.
    Training data is read chunk by chunk with ``getData`` from files named
    ``{data_dir}/observations_{k}.npy`` / ``{data_dir}/actions_{k}.npy``, where
    ``k = i * chunk_stride`` for ``i`` in ``range(num_chunks)``. The chunks
    are concatenated before calling ``hnm.train(X, y, 1)``.

    Args:
        num_chunks: How many observation/action file pairs to load
            (default 10, as originally hard-coded).
        chunk_stride: Step between consecutive file suffixes (default 500).
        data_dir: Directory holding the ``.npy`` files (default ``"data"``).
    """
    obstacle, wall_cw, wall_acw = Obstacle(), WallCW(), WallACW()
    obstacle_params = torch.load('program_memory/move.pt')
    wall_cw_params = torch.load('program_memory/cw.pt')
    wall_acw_params = torch.load('program_memory/acw.pt')

    networks = [obstacle, wall_cw, wall_acw]
    params = [obstacle_params, wall_cw_params, wall_acw_params]

    hnm = HNM(10, 14, networks, params)

    X, y = [], []
    for i in range(num_chunks):
        offset = i * chunk_stride
        tempX, tempY = getData(f"{data_dir}/observations_{offset}.npy",
                               f"{data_dir}/actions_{offset}.npy")
        X.extend(tempX)
        y.extend(tempY)

    # X and y are extended in parallel; equal lengths here indicate that the
    # observation and action chunks lined up.
    print(len(X), len(y))

    hnm.train(X, y, 1)