def linear_regression(a=1.0, b=0.0):
    """Fit a single linear unit to the line ``y = a * x + b`` and plot the fit.

    200 evenly spaced samples on [-100, 100] (as an (n, 1) column) are split
    with ``split_data(ratio=0.8, random=True)`` into train and test inputs;
    targets are the noiseless line values. A one-neuron ``Dense`` model is
    trained with MSE, then the test targets (blue) and the model's test
    predictions (red) are plotted.

    Args:
        a: Slope of the target line.
        b: Intercept of the target line.
    """
    samples = np.linspace(-100, 100, 200).reshape((-1, 1))
    train_x, test_x = split_data(samples, ratio=0.8, random=True)

    def line(values):
        # Noiseless targets on the line y = a * x + b.
        return a * values + b

    train_y = line(train_x)
    test_y = line(test_x)

    # A single linear neuron is enough to represent the target exactly.
    inputs = Input(1)
    outputs = Dense(1)(inputs)

    trainer = Trainer(
        loss='mse',
        optimizer=Adam(learning_rate=0.2),
        batch_size=50,
        epochs=50,
    )
    model = Sequential(inputs, outputs, trainer)
    model.summary()

    model.fit(train_x, train_y)

    predictions = model.predict(test_x)
    plt.plot(test_x, test_y, 'b')
    plt.plot(test_x, predictions, 'r')
    plt.show()
# Example #2
# 0
def linear_classification(a=1.0, b=0.0, graph=False):
    """Train a softmax classifier to separate noisy points by ``y = a * x + b``.

    200 points are taken along the line for x in [-100, 100]. Both
    coordinates are then perturbed with Gaussian noise (standard normal
    scaled by 100). A point is labelled 1 when ``a * x + b > y``, i.e. it
    lies below the line, and 0 otherwise. Labels are one-hot encoded with
    ``to_one_hot``. A single softmax ``Dense`` layer is trained, and its
    test-set evaluation is printed.

    Args:
        a: Slope of the separating line.
        b: Intercept of the separating line.
        graph: When true, also plot the training loss history and the
            predicted test-set classes via ``simple_plot``.
    """
    # prepare data: noisy (x, y) points scattered around the line
    line_x = np.linspace(-100, 100, 200)
    line_y = a * line_x + b
    points = np.column_stack((line_x, line_y)) + np.random.randn(200, 2) * 100
    below_line = a * points[:, 0] + b > points[:, 1]
    labels = to_one_hot(np.where(below_line, 1, 0))
    (train_x, train_y), (test_x, test_y) = split_data(
        points, labels, ratio=0.8, random=True)

    # one softmax layer mapping the 2 coordinates to 2 class outputs
    inputs = Input(2)
    outputs = Dense(2, activation='softmax')(inputs)

    trainer = Trainer(
        loss='cross_entropy',
        optimizer=Adam(learning_rate=0.05),
        batch_size=50,
        epochs=50,
        metrics=['accuracy'],
    )
    model = Sequential(inputs, outputs, trainer)
    model.summary()

    model.fit(train_x, train_y)
    print(model.evaluate(test_x, test_y))

    if not graph:
        return

    plt.plot(model.history['loss'])
    plt.show()

    # collapse per-class outputs to the index of the most likely class
    predicted = np.argmax(model.predict(test_x), axis=1)
    simple_plot(test_x, predicted, a, b)
# Example #3
# 0
def binary_classification():
    """Train a small MLP on the binary-classification ARFF dataset and plot it.

    Loads ``training.arff`` from ``data/examples/binary_classification`` and
    trains a 2 -> 30 -> 30 -> 2 network (ReLU hidden layers, softmax output)
    with cross-entropy. It plots the training loss history, then predicts on
    ``test.arff`` and plots the predicted class indices with ``simple_plot``.
    """

    def separate_label(data):
        """Split loaded ARFF data into normalized features and 0/1 labels.

        Assumes columns 0-1 hold the features and column 2 holds a byte-string
        class label (TODO confirm against ``load_arff``). ``b'black'`` maps to
        0 and every other label maps to 1.
        """
        X = normalize(data[:, :2].astype('float32'))
        Y = np.where(data[:, 2] == b'black', 0, 1)
        return X, Y

    data_dir = "data/examples/binary_classification"

    # prepare train data
    train_data = load_arff(os.path.join(data_dir, 'training.arff'))
    train_x, train_y = separate_label(train_data)
    train_y = to_one_hot(train_y)

    # build simple FNN
    i = Input(2)
    x = Dense(30, activation='relu')(i)
    x = Dense(30, activation='relu')(x)
    x = Dense(2, activation='softmax')(x)

    # define trainer
    trainer = Trainer(loss='cross_entropy',
                      optimizer=Adam(clipvalue=1.0),
                      batch_size=256,
                      epochs=500,
                      metrics=['accuracy'])

    # create model
    model = Sequential(i, x, trainer)
    model.summary()

    # training process
    model.fit(train_x, train_y)

    loss = model.history['loss']
    plt.plot(range(len(loss)), loss)
    plt.show()

    # predict (the test labels are not needed for plotting predictions)
    test_data = load_arff(os.path.join(data_dir, 'test.arff'))
    test_x, _ = separate_label(test_data)

    # The softmax layer yields one probability per class. Reduce it to the
    # predicted class index before plotting, the same way the other
    # classification example prepares its input to simple_plot.
    y_hat = np.argmax(model.predict(test_x), axis=1)
    simple_plot(test_x, y_hat)
# Example #4
# 0
def universal_approximation(f, x):
    """Train a one-hidden-layer ReLU network to approximate ``f`` and plot it.

    ``x`` is split with ``split_data(ratio=0.8, random=True)`` into train and
    test inputs. Targets come from ``f``. A 1 -> 50 (ReLU) -> 1 network is
    trained with MSE and an exponentially decaying Adam learning rate. The
    training wall time is printed in seconds, the loss history is plotted,
    and then the original vs. predicted curves are plotted on the sorted
    test inputs.

    Args:
        f: Target function, applied to whole sample arrays at once.
        x: Sample inputs. Presumably an (n, 1) array to match ``Input(1)``;
            TODO confirm.
    """
    train_x, test_x = split_data(x, ratio=0.8, random=True)
    train_y = f(train_x)

    # Sort the held-out inputs so the line plots below are drawn left to
    # right instead of zig-zagging between randomly ordered points.
    test_x = np.sort(test_x, axis=0)
    test_y = f(test_x)

    # build simple FNN (layer names avoid shadowing the data parameter `x`)
    inputs = Input(1)
    hidden = Dense(50, activation='relu')(inputs)
    outputs = Dense(1)(hidden)

    # define trainer
    schedule = ExponentialDecay(initial_learning_rate=0.01, decay_rate=0.75)
    trainer = Trainer(loss='mse',
                      optimizer=Adam(learning_rate=schedule),
                      batch_size=50,
                      epochs=750)

    # create model
    model = Sequential(inputs, outputs, trainer)
    model.summary()

    # training process. perf_counter is monotonic and high-resolution, so
    # the measured duration is unaffected by system clock adjustments
    # (unlike time.time()).
    start = time.perf_counter()
    model.fit(train_x, train_y)
    print(time.perf_counter() - start)

    loss = model.history['loss']
    plt.plot(range(len(loss)), loss)
    plt.show()

    # predict
    y_hat = model.predict(test_x)
    plt.plot(test_x, test_y, 'b-', label='original')
    plt.plot(test_x, y_hat, 'r-', label='predicted')
    plt.legend()
    plt.show()