예제 #1
0
파일: network.py 프로젝트: cooliotseng/tfnn
 def save(self, name='new_model', path=None, global_step=None, replace=False):
     """Save this network through a lazily created ``tfnn.NetworkSaver``.

     The saver is created on the first call and cached on ``self._saver``;
     subsequent calls reuse the cached instance.

     Args:
         name: Forwarded to the saver as its second positional argument.
         path: Forwarded to the saver as-is (may be ``None``).
         global_step: Forwarded to the saver as-is (may be ``None``).
         replace: Forwarded to the saver as the ``replace`` keyword.
     """
     try:
         saver = self._saver
     except AttributeError:
         # First save on this instance: build the saver and cache it.
         saver = self._saver = tfnn.NetworkSaver()
     saver.save(self, name, path, global_step, replace=replace)
예제 #2
0
파일: road_test.py 프로젝트: ten2net/tfnn
def compare_real(data_path, start=10700, stop=10900, restore_path='/tmp/'):
    """Plot a restored network's predictions against the recorded values.

    Loads a pickled DataFrame from ``data_path`` and takes the rows
    ``start:stop`` by position. It feeds every column except the first into
    the network restored from ``restore_path``. The predictions are plotted
    (red dashed) against column ``a`` (black solid).

    Args:
        data_path: Path to a pandas pickle. ``pd.read_pickle`` unpickles the
            file, so only pass trusted files.
        start: First row position of the evaluation window (default matches
            the previously hard-coded value).
        stop: Row position one past the end of the window.
        restore_path: Location handed to ``tfnn.NetworkSaver.restore``.
    """
    load_data = pd.read_pickle(data_path)
    xs = load_data.iloc[start:stop, 1:]
    # Positional slice, consistent with how ``xs`` is selected above.
    ys = load_data['a'].iloc[start:stop]

    network = tfnn.NetworkSaver().restore(restore_path)
    try:
        # NOTE(review): fit_transform refits the normalizer on this evaluation
        # window instead of reusing training statistics -- confirm intended.
        prediction = network.predict(network.normalizer.fit_transform(xs))
        steps = np.arange(xs.shape[0])
        plt.plot(steps, prediction, 'r--', label='predicted')
        plt.plot(steps, ys, 'k-', label='real')
        plt.legend(loc='best')
        plt.show()
    finally:
        # Release the network's session even if prediction or plotting fails.
        network.sess.close()
예제 #3
0
 def save(self, path='/tmp/'):
     """Persist this network with a freshly created ``tfnn.NetworkSaver``.

     NOTE(review): ``path`` is passed as the second positional argument of
     ``NetworkSaver.save``. In the saver API that takes
     ``(network, name, path, ...)``, that slot is ``name``, not ``path``.
     Confirm which saver version this method targets.
     """
     tfnn.NetworkSaver().save(self, path)
예제 #4
0
# Build a classifier sized to the MNIST feature and label widths.
n_features = mnist.train.images.shape[1]
n_classes = mnist.train.labels.shape[1]
network = tfnn.ClfNetwork(input_size=n_features, output_size=n_classes)

# One ReLU hidden layer of 100 units, followed by an output layer with no
# activation function.
network.add_hidden_layer(n_neurons=100, activator=tfnn.nn.relu)
network.add_output_layer(activator=None)

# Optimizer with the library defaults (GradientDescent, per the original note).
network.set_optimizer()

# Evaluator used to compute accuracy, loss, etc. for this network.
evaluator = tfnn.Evaluator(network)

# sklearn-style training entry point.
network.fit(mnist.train.images, mnist.train.labels, steps=2000)

# Compare predictions with ground truth on a small slice of the test set.
test_window = slice(10, 20)
print('trained network predict:')
print(network.predict(mnist.test.images[test_window]))
print('real data value:')
print(mnist.test.labels[test_window].argmax(axis=1))

# Round-trip the model through disk and predict again with the reloaded copy.
network.save(name='model', path='tmp')
saver = tfnn.NetworkSaver()
network2 = saver.restore(name='model', path='tmp')
print('\nLoaded network predict:')
print(network2.predict(mnist.test.images[test_window]))
예제 #5
0
파일: road_test.py 프로젝트: ten2net/tfnn
def _leader_acceleration(step):
    """Return the scripted acceleration of the lead car at tick ``step``.

    Ticks are 0.1 s apart, so each boundary in seconds is multiplied by 10.
    Every interval is half-open, [previous boundary, boundary). After the last
    boundary the acceleration is 0.
    """
    # (end of interval in seconds, acceleration inside the interval)
    profile = ((1, 0), (15, 2), (20, 0), (28, -3), (30, 0),
               (35, 3), (37, 0), (45, -1), (50, 2))
    for end_seconds, accel in profile:
        if step < end_seconds * 10:
            return accel
    return 0


def _advance_car(car, a):
    """Advance ``car`` by one 0.1 s tick with acceleration ``a``.

    The new position uses the acceleration recorded on the previous tick
    (``car.acs[-1]``). The new speed integrates ``a`` and is clipped at 0.
    ``a`` is then recorded in ``car.acs``.
    """
    car.ps.append(car.ps[-1] + car.vs[-1] * 0.1 +
                  1 / 2 * car.acs[-1] * 0.1**2)
    v = car.vs[-1] + 0.1 * a
    if v < 0:
        v = 0
    car.vs.append(v)
    car.acs.append(a)


def _plot_traces(cars, test_time):
    """Plot position, acceleration, speed and gap traces in one figure.

    The lead car (index 0) is drawn as a solid black line and followers as
    red dashed lines. The gap panel shows followers only.
    """
    xs = list(range(test_time * 10 + 1))
    plt.figure(1)
    # (subplot code, Car attribute, y-axis label, include the lead car)
    panels = (
        (411, 'ps', 'p (m)', True),
        (412, 'acs', 'a (m/s^2)', True),
        (413, 'vs', 'v (m/s)', True),
        (414, 'ss', 'space (m)', False),
    )
    for subplot_code, attr, ylabel, with_leader in panels:
        plt.subplot(subplot_code)
        first = 0 if with_leader else 1
        for idx in range(first, len(cars)):
            style = 'k-' if idx == 0 else 'r--'
            plt.plot(xs, getattr(cars[idx], attr), style)
        plt.ylabel(ylabel)
        plt.grid()
    plt.show()


def test():
    """Simulate a platoon whose followers are driven by the restored network.

    Eight cars are created with offsets ``0, -15, ..., -105`` passed to
    ``Car``. The lead car follows a scripted acceleration profile.

    Followers hold zero acceleration while the tick index is <= 10.
    After that, each follower's acceleration is predicted by the network.
    The input is the follower's last 10 gaps and speeds plus the last 10
    speeds of the car directly in front.

    The run lasts 60 s in 0.1 s ticks. Cars are updated front to back, so
    each gap is measured against the front car's already-updated position.
    """
    network_saver = tfnn.NetworkSaver()
    restore_path = '/tmp/'
    network = network_saver.restore(restore_path)
    test_time = 60
    cars = [Car(i * -15) for i in range(8)]

    try:
        for i in range(test_time * 10):
            for j, car in enumerate(cars):
                if j == 0:
                    _advance_car(car, _leader_acceleration(i))
                    continue
                front = cars[j - 1]
                if i <= 1 * 10:
                    a = 0
                else:
                    # assumes ss/vs are plain lists (concatenated with +)
                    xs_data = np.array(car.ss[-10:] + car.vs[-10:] +
                                       front.vs[-10:])
                    # NOTE(review): fit_transform refits the normalizer on a
                    # single sample each tick -- confirm this is intended
                    # rather than reusing training statistics.
                    a = network.predict(
                        network.normalizer.fit_transform(xs_data))
                _advance_car(car, a)
                car.ss.append(front.ps[-1] - car.ps[-1])
    finally:
        # Release the network's session (compare_real does the same); the
        # network is no longer needed once the simulation is done.
        network.sess.close()

    _plot_traces(cars, test_time)