Esempio n. 1
0
def test_srbm_regression(show_figure=False, block_figure_on_end=False):
    """Fit a supervised RBM regressor on MNIST and compare it with OLS.

    The MNIST labels are used as regression targets (``task='regression'``).
    Train and test images are stacked into one array, and the ``cv`` index
    vector passed to the model marks which rows came from which split.
    Three ``Display`` callbacks monitor the run:

    * learning curves
    * receptive fields
    * hidden activations of the first 100 training images

    Early stopping and checkpointing both watch ``val_loss``. A scikit-learn
    ``LinearRegression`` fitted on the training split is reported last as a
    baseline.

    Args:
        show_figure: forwarded as ``show`` to every ``Display`` callback.
        block_figure_on_end: forwarded as ``block_on_end`` to every
            ``Display`` callback.
    """
    print("========== Test Supervised RBM for Regression ==========")

    from sklearn.metrics import mean_squared_error
    from sklearn.linear_model import LinearRegression

    np.random.seed(random_seed())

    (x_train, y_train), (x_test, y_test) = demo.load_mnist()
    x = np.vstack([x_train, x_test])
    y = np.concatenate([y_train, y_test])

    # Keyword arguments shared by all three Display callbacks.
    display_opts = dict(dpi='auto',
                        freq=1,
                        show=show_figure,
                        block_on_end=block_figure_on_end)

    def line_panel(metric, description, title, ylabel):
        # One line-plot panel tracking `metric` and its "val_" counterpart.
        return {
            'metrics': [metric, 'val_' + metric],
            'type': 'line',
            'labels': ["training " + description,
                       "validation " + description],
            'title': title,
            'xlabel': "epoch",
            'ylabel': ylabel,
        }

    learning_display = Display(
        title="Learning curves",
        layout=(3, 1),
        monitor=[
            line_panel('recon_err', "recon error",
                       "Reconstruction Errors", "error"),
            line_panel('loss', "loss", "Learning Losses", "loss"),
            line_panel('err', "error", "Prediction Errors", "error"),
        ],
        **display_opts)

    filter_display = Display(
        title="Receptive Fields",
        layout=(1, 1),
        figsize=(8, 8),
        monitor=[{
            'metrics': ['filters'],
            'title': "Receptive Fields",
            'type': 'img',
            'num_filters': 15,
            'disp_dim': (28, 28),
            'tile_shape': (3, 5),
        }],
        **display_opts)

    hidden_display = Display(
        title="Hidden Activations",
        layout=(1, 1),
        figsize=(8, 8),
        monitor=[{
            'metrics': ['hidden_activations'],
            'title': "Hidden Activations",
            'type': 'img',
            'data': x_train[:100],
        }],
        **display_opts)

    early_stopping = EarlyStopping(monitor='val_loss', patience=2, verbose=1)
    checkpoint_path = os.path.join(
        model_dir(), "male/sRBM/mnist_{epoch:04d}_{val_loss:.6f}.pkl")
    checkpoint = ModelCheckpoint(checkpoint_path,
                                 mode='min',
                                 monitor='val_loss',
                                 verbose=0,
                                 save_best_only=True)

    # Training rows are tagged -1 and test rows 0. Presumably this follows the
    # PredefinedSplit convention, so the test split serves as the validation
    # fold behind the val_* metrics -- confirm against SupervisedRBM.
    num_train, num_test = x_train.shape[0], x_test.shape[0]
    split_ids = [-1] * num_train + [0] * num_test

    model = SupervisedRBM(
        task='regression',
        num_hidden=15,
        num_visible=784,
        batch_size=100,
        num_epochs=4,
        w_init=0.01,
        learning_rate=0.01,
        momentum_method='sudden',
        weight_cost=0.0,
        inference_engine='variational_inference',
        approx_method='first_order',
        metrics=['recon_err', 'loss', 'err'],
        callbacks=[filter_display, learning_display, hidden_display,
                   early_stopping, checkpoint],
        cv=split_ids,
        random_state=random_seed(),
        verbose=1)

    model.fit(x, y)

    recon_err = model.get_reconstruction_error(x_test).mean()
    print("Test reconstruction error = %.4f" % recon_err)
    test_loss = model.get_loss(x_test, y_test)
    print("Test loss = %.4f" % test_loss)

    print("=========== Predicted by sRBM ============")
    # score() is negated for display. Presumably it returns negative MSE
    # (higher-is-better convention) -- confirm against SupervisedRBM.
    srbm_train_mse = -model.score(x_train, y_train)
    srbm_test_mse = -model.score(x_test, y_test)
    print("Train MSE = {0:>1.4f}\tTest MSE = {1:>1.4f}".format(
        srbm_train_mse, srbm_test_mse))

    # Ordinary least-squares baseline, fitted on the training split only.
    baseline = LinearRegression()
    baseline.fit(x_train, y_train)
    print("=========== Predicted by Linear Regressor ============")
    baseline_train_mse = mean_squared_error(y_train, baseline.predict(x_train))
    baseline_test_mse = mean_squared_error(y_test, baseline.predict(x_test))
    print("Train MSE = {0:>1.4f}\tTest MSE = {1:>1.4f}".format(
        baseline_train_mse, baseline_test_mse))