Exemplo n.º 1
0
# Inputs for the single-split VGG experiment below: the label CSV and image
# directory are passed to get_data(), the model directory to get_classifier().
# NOTE(review): DATA_PATH, MODEL_PATH and EPOCH are reassigned later in this
# file for the k-fold run.
LABEL_PATH = r'/home/xujingning/ocean/ocean_data/label.csv'
DATA_PATH = '/home/xujingning/ocean/ocean_data/data_img/'
MODEL_PATH = '/home/xujingning/ocean/ocean_data/VGG/'
EPOCH = 1  # NOTE(review): not read by the visible single-split code — confirm it is still needed
# Presumably silences TensorFlow's C++ INFO/WARNING log output — confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


def add_noise(_img, max_noise=5):
    """Return ``_img`` with small random non-negative integer noise added.

    Each element receives an independent integer drawn uniformly from
    ``[0, max_noise)`` via ``np.random.randint``. The noise has the same shape
    as the image, so any image shape is accepted (previously the noise was
    fixed at 224x224x3).

    Args:
        _img: image array (or array-like) of any shape.
        max_noise: exclusive upper bound of the per-element noise; defaults to
            5 as before. Must fit the ``uint8`` noise dtype (<= 256).

    Returns:
        A new array with the same shape as ``_img``. Its dtype is the usual
        NumPy promotion of ``uint8`` with ``_img.dtype``, as before (e.g. uint8
        stays uint8, float32 stays float32). For integer results up to
        32 bits, the sum saturates at the dtype's maximum instead of wrapping
        around (uint8: 253 + 4 -> 255, not 1).
    """
    img = np.asarray(_img)
    # Keep dtype='uint8' so seeded runs draw the same noise values as before.
    noise = np.random.randint(max_noise, size=img.shape, dtype='uint8')
    out_dtype = np.result_type(noise.dtype, img.dtype)
    if out_dtype.kind in 'ui' and out_dtype.itemsize < 8:
        # Add in int64, which cannot overflow for these dtypes, then clip
        # back into the result dtype's range instead of wrapping.
        info = np.iinfo(out_dtype)
        wide = img.astype(np.int64) + noise.astype(np.int64)
        return np.clip(wide, info.min, info.max).astype(out_dtype)
    # Float images (and 64-bit integers, which cannot be widened losslessly)
    # use plain addition, exactly as before.
    return noise + img


# Single-split setup: load labels and images with get_data(), build a
# 'VGG_net' classifier, and hold out 20% of the samples as a test split
# (random_state=42 makes the split reproducible).
# NOTE(review): get_data/get_classifier are imported only further down this
# file, and train_test_split is not imported anywhere visible. Unless they are
# imported earlier, these lines raise NameError — confirm.
# NOTE(review): clf and the four split arrays are not used by any visible code
# (clf is rebound in the k-fold loop), so this section may be dead work.
labels, imgs, _ = get_data(DATA_PATH, LABEL_PATH)
clf = get_classifier(MODEL_PATH, 'VGG_net')
total_imgs_train, total_imgs_test, total_labels_train, total_labels_test = train_test_split(
    imgs, labels, test_size=0.2, random_state=42)
# Disabled experiment: oversample the training split with SMOTE. The 4-D image
# array is flattened to (nsamples, nx*ny*nz) before resampling and reshaped back
# to (n, nx, ny, nz) afterwards. An ADASYN variant for the test split is also
# commented out.
# print(len(total_imgs_train), len(total_labels_train), len(total_imgs_test), len(total_labels_test))
# nsamples, nx, ny, nz = total_imgs_train.shape
# d2_train_dataset = total_imgs_train.reshape((nsamples, nx*ny*nz))
# #nsamples, nx, ny, nz = total_imgs_test.shape
# #d2_test_dataset = total_imgs_test.reshape((nsamples, nx*ny*nz))
#
# xx, yy = SMOTE().fit_sample(d2_train_dataset, total_labels_train)
# #xxx, yyy = ADASYN().fit_sample(d2_test_dataset, total_labels_test)
#
# xx = xx.reshape((len(xx), nx, ny, nz))
# #xxx = xxx.reshape((len(xx), nx, ny, nz))
# print(len(xx), len(yy))
# print(xx.shape)
sys.path.append('../')
from prediction.predict import get_prediction
from training.train import get_classifier, train_model
from data_IO.data_reader import get_data
from statistic.calculation import confusion_matrix

# Presumably silences TensorFlow's C++ INFO/WARNING log output.
# NOTE(review): the modules imported above may already have loaded TensorFlow.
# Confirm this still takes effect when set here.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

EPOCH = 100  # number of train_model() calls per fold
DATA_PATH = '/home/liuyajun/yolo_motion/yolo/yolo_motion/output/w03/'
K_FOLDS = 5  # folds k0 .. k4, one label CSV pair per fold

# K-fold evaluation. For each fold k:
#   - train an 'alex_net' classifier on that fold's training labels, unless
#     the fold's model directory already exists;
#   - predict on the fold's test labels and print the confusion matrix.
for k in range(K_FOLDS):
    MODEL_PATH = '/home/liuyajun/concat_net/model/yolo_motion/w03/k%d/' % k
    TRAIN_LABEL_PATH = '/home/liuyajun/concat_net/data/kfold_label/train_k%d_label.csv' % k
    TEST_LABEL_PATH = '/home/liuyajun/concat_net/data/kfold_label/test_k%d_label.csv' % k

    # Decide whether to train BEFORE building the classifier. If
    # get_classifier() creates MODEL_PATH when constructed, checking afterwards
    # would always find the directory and silently skip training.
    needs_training = not os.path.exists(MODEL_PATH)

    clf = get_classifier(MODEL_PATH, 'alex_net')
    if needs_training:
        print('Training k%i ....' % k)
        train_label, train_data, _ = get_data(DATA_PATH, TRAIN_LABEL_PATH)
        for _epoch in range(EPOCH):
            train_model(clf, {'X1': train_data}, train_label)

    test_label, test_data, _ = get_data(DATA_PATH, TEST_LABEL_PATH)
    pred = get_prediction(clf, {'X1': test_data}, test_label)
    print('Result of k%i:' % k)
    confusion_matrix(pred, test_label, show_mat=True)
    # Drop this fold's classifier before the next fold builds its own.
    del clf