コード例 #1
0
ファイル: test_gradient.py プロジェクト: rev112/pcml_mnist
def learn_with_gradient_testing(fname):
    """Train a small MLP on the data in a MAT file and print its errors.

    The network is created as ``Mlp([10], d, True)``; the third argument is
    presumably the gradient-testing switch, going by this function's name
    (confirm against ``Mlp``).  The training set is passed to
    ``train_network`` both as training and as validation data.

    Args:
        fname: Path of the MAT file.  It must contain the variables
            ``Xtrain`` and ``Ytrain``; ``Xtest`` and ``Ytest`` are optional
            and, when present, are used for one extra error report.

    Raises:
        KeyError: If ``Xtrain`` or ``Ytrain`` is missing from the file.
    """
    matFileContent = scipy.io.loadmat(fname)  # corresponding MAT file
    x_train = np.array(matFileContent['Xtrain'].tolist())
    t_train = np.array(matFileContent['Ytrain'].flatten())

    # The number of columns of the training matrix is handed to Mlp as the
    # input dimension.
    d = x_train.shape[1]
    hidden_layers_list = [10]

    mlp = Mlp(hidden_layers_list, d, True)
    stopping_criterion = Mlp.BasicStoppingCriterion(0.05, 10)

    error_data = mlp.train_network(x_train, t_train, x_train, t_train,
            stopping_criterion)

    # One repr() per line.  str.join (unlike reduce without an initial
    # value) also copes with an empty result.  The single-argument
    # parenthesised print form behaves the same on Python 2 and Python 3.
    print("Error data:")
    print("\n".join(repr(entry) for entry in error_data))

    print("Log error:")
    print(mlp.get_input_error(x_train, t_train))

    # The test set is optional: only a missing variable is tolerated.  Any
    # other failure (bad shapes, errors inside the network, Ctrl-C) must
    # surface instead of being silently discarded.
    try:
        raw_x_test = matFileContent['Xtest']
        raw_t_test = matFileContent['Ytest']
    except KeyError:
        return

    x_test = np.array(raw_x_test.tolist())
    t_test = np.array(raw_t_test.flatten())
    print("Test log error:")
    print(mlp.get_input_error(x_test, t_test))
コード例 #2
0
ファイル: train_mlp.py プロジェクト: rev112/pcml_mnist
def learn(argv):
    """Train an MLP on a MAT data set, plot its errors and pickle the result.

    Args:
        argv: Command-line style argument list.  ``argv[1]`` is the path of
            the MAT file, which must contain ``TrainSet``/``TrainClass``,
            ``ValidSet``/``ValidClass`` and ``TestSet``/``TestClass``.
            When exactly three items are given, ``argv[2]`` is a Python
            literal used as the hidden-layer list passed to ``Mlp``
            (e.g. ``"[20, 10]"``); otherwise ``[10]`` is used.

    Raises:
        IndexError: If ``argv`` has no element at index 1.
        ValueError, SyntaxError: If ``argv[2]`` is not a valid Python
            literal.
        KeyError: If one of the expected variables is missing from the file.

    Side effects:
        Prints error/accuracy figures, passes a file name under ``plots/``
        to ``plot_network_errors`` and writes the trained network to
        ``trained_network.dat`` in the current directory.
    """
    import ast  # local import: only needed to parse the architecture argument

    fname = argv[1]
    if len(argv) == 3:
        # SECURITY: this used to be eval(argv[2]), which executes arbitrary
        # code taken from the command line.  literal_eval only accepts plain
        # literals such as "[20, 10]"; expressions like "[10]*2" are no
        # longer accepted.
        hidden_layers_list = ast.literal_eval(argv[2])
    else:
        hidden_layers_list = [10]

    matFileContent = scipy.io.loadmat(fname)  # corresponding MAT file
    x_train = np.array(matFileContent['TrainSet'].tolist())
    t_train = np.array(matFileContent['TrainClass'].tolist())

    x_valid = np.array(matFileContent['ValidSet'].tolist())
    t_valid = np.array(matFileContent['ValidClass'].tolist())

    # The number of columns of the training matrix is handed to Mlp as the
    # input dimension.
    d = x_train.shape[1]

    mlp = Mlp(hidden_layers_list, d)
    stopping_criterion = Mlp.EarlyStoppingCriterion(5, 1e-5)

    (error_data, best_epoch) = mlp.train_network(x_train, t_train,
            x_valid, t_valid, stopping_criterion)

    # Build the plot file name from the default hyper-parameters and the
    # architecture.  [2:] drops the first two characters of each value's
    # text -- presumably the leading "0." of values such as 0.01; verify if
    # the defaults can ever be >= 1 or print in exponent notation.
    lrate = defaults.LEARNING_RATE_DEFAULT
    mterm = defaults.MOMENTUM_TERM_DEFAULT
    terms = str(lrate)[2:] + "_" + str(mterm)[2:]
    # Every entry is prefixed with "_", e.g. [20, 10] -> "_20_10".
    arch_desc = "".join("_" + str(entry) for entry in hidden_layers_list)
    plt_file = 'plots/errors_' + terms + arch_desc + '.png'
    plot_network_errors(error_data, best_epoch, plt_file)

    # The single-argument parenthesised print form emits the same text as
    # the former "print a, b, '%'" statements and behaves identically on
    # Python 2 and Python 3.
    print("Train log error and accuracy:")
    print("%s %s %%" % (mlp.get_input_error(x_train, t_train),
                        mlp.get_accuracy(x_train, t_train)))
    print("Valid log error and accuracy:")
    print("%s %s %%" % (mlp.get_input_error(x_valid, t_valid),
                        mlp.get_accuracy(x_valid, t_valid)))

    x_test = np.array(matFileContent['TestSet'].tolist())
    t_test = np.array(matFileContent['TestClass'].tolist())
    print("Test log error and accuracy:")
    print("%s %s %%" % (mlp.get_input_error(x_test, t_test),
                        mlp.get_accuracy(x_test, t_test)))

    # Close the file deterministically so the pickle is fully flushed to
    # disk even on interpreters without reference counting.
    with open('trained_network.dat', 'wb') as out_file:
        pickle.dump(mlp, out_file)