Esempio n. 1
0
def plain_run(alltrain, test_set, labels):
    """Train and evaluate once on a single train/validation split.

    ``alltrain`` is divided into a training and a validation part with
    ``datahandler.split_data`` (fraction taken from ``config.val_split``),
    bundled with ``test_set`` and handed to ``CNN_framework.run``.

    Side effects:
        * ``config.disp_train`` is overwritten.
        * If ``config.plot`` is truthy, label histograms for the test,
          training and validation parts are plotted into ``config.logspath``.
        * If ``config.write_file`` is truthy, the results are written to
          ``config.logspath + 'CV_log.txt'`` (an existing file is replaced).

    Args:
        alltrain: object exposing ``data`` and ``labels`` attributes.
        test_set: test data, passed through to ``datahandler.make_datasets``.
        labels: class labels, forwarded to ``CNN_framework.run`` and to the
            histogram plots.

    Returns:
        The test accuracy reported by ``CNN_framework.run``.
    """
    # split whole "train" set into train and validation
    train, val = datahandler.split_data(alltrain.data, alltrain.labels, config.val_split)
    # steps per epoch regardless batch size
    config.disp_train = int(train.num_examples / (config.display_steps * config.batch_size))
    dataset = datahandler.make_datasets(train, val, test_set)

    # The validation accuracy is returned by the framework but not used here.
    testacc, trainacc, trainacc_last, duration, _valacc = \
        CNN_framework.run(config,
                          dataset=dataset,
                          classlabels=labels)
    # plot histograms
    if config.plot:
        plotlib.plot_histogram(dataset.test.labels, 'Test data', config.logspath, 'test_histogram', labels, export=config.export)
        plotlib.plot_histogram(dataset.train.labels, 'Training data', config.logspath, 'train_histogram', labels, export=config.export)
        plotlib.plot_histogram(dataset.validation.labels, 'Validation data', config.logspath, 'validation_histogram', labels, export=config.export)
    print("test accuracy: %s" % testacc)
    print("train accuracy: %s" % trainacc)
    print("last train acc: %s" % trainacc_last)

    if config.write_file:
        # Context manager guarantees the log is flushed and closed even if a
        # write fails part-way through.
        with open(config.logspath + 'CV_log.txt', 'w') as cvlog:
            cvlog.write("test accuracies: %s\n" % testacc)
            cvlog.write("train accuracies: %s\n" % trainacc)
            cvlog.write("last train acc: %s\n" % trainacc_last)
            # Each entry ends with a newline so the last two values do not
            # run together on one line.
            cvlog.write("average test accuracy: %s\n" % testacc)
            cvlog.write("average training time: %s\n" % duration)

    return testacc
Esempio n. 2
0
def crossvalidate_run(train_dataset, test_set, labels, splits=10):
    """Run 'splits'-fold cross validation, writing results to CV_log.txt.

    ``train_dataset`` is partitioned with a shuffled ``KFold``; for every
    fold a training and a validation set are built and passed, together with
    the fixed ``test_set``, to ``CNN_framework.run``.

    Side effects:
        * ``config.logspath`` is set to a per-fold directory
          (``<basepath>fold_<i>/``) and ``config.disp_train`` is recomputed
          for every fold.
        * If ``config.plot`` is truthy, label histograms are plotted (test
          set once into the base directory, train/validation per fold).
        * If ``config.write_file`` is truthy, per-fold and averaged results
          are written to ``CV_log.txt`` in the base directory.

    Args:
        train_dataset: object exposing index-able ``data`` and ``labels``.
        test_set: test data; its ``labels`` attribute is read when plotting.
        labels: class labels, forwarded to ``CNN_framework.run`` and plots.
        splits: number of folds handed to ``KFold`` (default 10).

    Returns:
        The mean of the per-fold test accuracies (``np.mean``).
    """
    # per-fold results of the different runs
    testaccuracy, trainaccuracy, trainaccuracy_last, times = [], [], [], []
    # Drops the last character of config.logspath before appending the suffix
    # (presumably a trailing path separator -- confirm against the config).
    basepath = config.logspath[:-1] + "_cv" + str(splits) + '/'
    kf = KFold(n_splits=splits, shuffle=True)  # initialize k-fold cross validation

    if config.write_file or config.plot:
        if not os.path.isdir(basepath):
            os.makedirs(basepath)  # make sure directory exists for log file and plots

    # test set stays same, but train/val is different for every split
    if config.plot:
        plotlib.plot_histogram(test_set.labels, 'Test data', basepath, 'test_histogram', labels)

    # cross validation loop
    for fold, (train_idx, val_idx) in enumerate(
            kf.split(train_dataset.data, train_dataset.labels)):
        # construct data set with train/val data for this "fold"
        train_set = datahandler.make_dataset(train_dataset.data[train_idx], train_dataset.labels[train_idx])
        val_set = datahandler.make_dataset(train_dataset.data[val_idx], train_dataset.labels[val_idx])
        # set own path for each run
        # NOTE(review): config.logspath is not restored afterwards, so it
        # still points at the last fold directory when this function returns.
        config.logspath = basepath + 'fold_' + str(fold) + '/'
        # steps per epoch regardless batch size
        config.disp_train = int(train_set.num_examples / (config.display_steps * config.batch_size))
        # run optimization and evaluation procedure
        testacc, trainacc, trainacc_last, duration, _valacc = \
            CNN_framework.run(config,
                              dataset=datahandler.make_datasets(train_set, val_set, test_set),
                              classlabels=labels)
        # plot histograms
        if config.plot:
            plotlib.plot_histogram(train_set.labels, 'Training data', config.logspath, 'train_histogram',
                                   labels)
            plotlib.plot_histogram(val_set.labels, 'Validation data', config.logspath,
                                   'validation_histogram',
                                   labels)
        testaccuracy.append(testacc)
        trainaccuracy.append(trainacc)
        trainaccuracy_last.append(trainacc_last)
        times.append(duration)
        print("test accuracy in fold %s: %s " % (fold, testacc))
        print("train accuracy in fold %s: %s" % (fold, trainacc))
        print("last train acc. in fold %s: %s" % (fold, trainacc_last))

    mean_testacc = np.mean(testaccuracy)
    if config.write_file:
        # The log is only opened when it is actually written, and the context
        # manager closes it on every path (no handle is left open, and no
        # empty file is created, when only plotting is enabled).
        with open(os.path.join(basepath, 'CV_log.txt'), 'w') as cvlog:
            cvlog.write("test accuracies: %s\n" % testaccuracy)
            cvlog.write("train accuracies: %s\n" % trainaccuracy)
            cvlog.write("last train acc: %s\n" % trainaccuracy_last)
            # Each entry ends with a newline so the last two values do not
            # run together on one line.
            cvlog.write("average test accuracy: %s\n" % mean_testacc)
            cvlog.write("average training time: %s\n" % np.mean(times))
    return mean_testacc
def plain_run(alltrain, test_set, labels):
    """Run one training/evaluation pass on a single train/validation split.

    The data and labels of ``alltrain`` are split by
    ``datahandler.split_data`` using ``config.val_split``; the resulting
    parts plus ``test_set`` are combined with ``datahandler.make_datasets``
    and passed to ``CNN_framework.run``. ``config.disp_train`` is overwritten
    as a side effect, and label histograms are plotted into
    ``config.logspath`` when ``config.plot`` is truthy.

    Args:
        alltrain: object exposing ``data`` and ``labels`` attributes.
        test_set: test data, passed through to ``datahandler.make_datasets``.
        labels: class labels, forwarded to ``CNN_framework.run`` and plots.

    Returns:
        Tuple ``(test accuracy, train accuracy, validation accuracy)`` as
        reported by ``CNN_framework.run``.
    """
    # Carve the validation part out of the complete "train" set.
    train_part, val_part = datahandler.split_data(
        alltrain.data, alltrain.labels, config.val_split)

    # steps per epoch regardless batch size
    n_examples = train_part.num_examples
    config.disp_train = int(
        n_examples / (config.display_steps * config.batch_size))

    bundle = datahandler.make_datasets(train_part, val_part, test_set)
    outcome = CNN_framework.run(config, dataset=bundle, classlabels=labels)
    # The fourth element (run duration) is not used by this variant.
    test_acc, train_acc, last_train_acc, _elapsed, val_acc = outcome

    if config.plot:
        # (attribute on the bundle, plot title, output file stem), plotted in
        # the order test -> train -> validation.
        histogram_specs = (
            ('test', 'Test data', 'test_histogram'),
            ('train', 'Training data', 'train_histogram'),
            ('validation', 'Validation data', 'validation_histogram'),
        )
        for part_name, title, stem in histogram_specs:
            part = getattr(bundle, part_name)
            plotlib.plot_histogram(part.labels, title, config.logspath, stem,
                                   labels, export=config.export)

    print("train accuracy: %s" % train_acc)
    print("last train acc: %s" % last_train_acc)
    print("val acc: %s" % val_acc)
    print("test acc: %s" % test_acc)

    return test_acc, train_acc, val_acc