示例#1
0
def train_cnn(x_all_data, y_all_data, max_steps=None):
    """Build the CNN graph and train it until evaluation accuracy exceeds 0.98.

    Relies on module-level names defined elsewhere in the file: the feed
    placeholders ``X``, ``Y`` and ``keep_prob``, and the helpers
    ``create_cnn_3muti3``, ``get_next_batch`` and ``print_info``.

    Args:
        x_all_data: Training inputs, handed to ``get_next_batch``.
        y_all_data: Training labels, handed to ``get_next_batch`` and fed to
            ``Y``. Accuracy compares ``argmax`` of ``Y`` with ``argmax`` of the
            logits, so one label row per sample with the class marked by its
            largest entry is expected.
        max_steps: Optional upper bound on the number of training steps.
            ``None`` (the default) keeps the original behaviour of looping
            until the accuracy threshold is reached.

    Every 10 steps (starting at step 0) accuracy is measured on a fresh batch
    with dropout disabled (``keep_prob`` 1.0). A checkpoint is written to
    ``./model_3/cnn.model`` (at most 3 kept) whenever that accuracy exceeds
    0.9, and training stops once it exceeds 0.98.
    """
    output = create_cnn_3muti3(5)

    print_info("created cnn ...")
    print_info("start train ...")

    # NOTE(review): the labels look like mutually exclusive classes, for which
    # softmax cross-entropy is the usual loss; sigmoid cross-entropy is kept to
    # preserve the original training behaviour -- confirm this is intended.
    loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=Y, logits=output))
    optimizer = tf.train.AdadeltaOptimizer(learning_rate=0.001).minimize(loss)

    # Predicted class vs. labelled class, compared per sample.
    max_idx_p = tf.argmax(output, 1)
    max_idx_l = tf.argmax(Y, 1)
    correct_pred = tf.equal(max_idx_p, max_idx_l)
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    saver = tf.train.Saver(max_to_keep=3)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        step = 0
        while max_steps is None or step < max_steps:
            batch_x, batch_y = get_next_batch(x_all_data, y_all_data)
            _, loss_ = sess.run([optimizer, loss],
                                feed_dict={
                                    X: batch_x,
                                    Y: batch_y,
                                    keep_prob: 0.75,
                                })
            print(step, loss_)

            if step % 10 == 0:
                # NOTE(review): this evaluation batch comes from the same
                # x_all_data / y_all_data used for training, so it measures
                # training accuracy rather than held-out accuracy.
                batch_x_test, batch_y_test = get_next_batch(
                    x_all_data, y_all_data)
                acc = sess.run(accuracy,
                               feed_dict={
                                   X: batch_x_test,
                                   Y: batch_y_test,
                                   keep_prob: 1.0,
                               })
                print(step, acc)
                if acc > 0.9:
                    # One save covers both thresholds; previously the same
                    # checkpoint was written twice when acc > 0.98.
                    saver.save(sess, "./model_3/cnn.model", global_step=step)
                if acc > 0.98:
                    break
            step += 1
示例#2
0
def read_test_data(file_in, size, rows=400, cols=400, bands=224):
    """Read a data cube and cut it into every ``size`` x ``size`` window.

    Args:
        file_in: Path handed to ``read_signal_txt``.
        size: Side length of each square window over the first two axes;
            must be >= 1.
        rows, cols, bands: Dimensions handed to ``read_signal_txt``. The
            defaults reproduce the original hard-coded 400, 400, 224.

    Returns:
        A list of windows ``data[i - size:i, j - size:j, :]``, one per valid
        end corner ``(i, j)``. The first axis varies fastest (``j`` is the
        outer loop), and each window keeps the full last axis. The list is
        empty if ``size`` exceeds either of the first two dimensions.

    Raises:
        ValueError: If ``size`` < 1. A zero or negative size previously
            produced empty or meaningless slices.
    """
    if size < 1:
        raise ValueError("size must be >= 1, got %r" % (size,))
    print_info("start read test data...")
    data = read_signal_txt(file_in, rows, cols, bands)
    print_info("end read test data...")
    # Unpacking three values keeps the original requirement of a 3-D cube.
    x, y, _ = data.shape
    return [data[i - size:i, j - size:j, :]
            for j in range(size, y + 1)
            for i in range(size, x + 1)]
示例#3
0
def read_test_data(file_in, size, rows=2100, cols=510, bands=224):
    """Read a data cube and cut it into every ``size`` x ``size`` window.

    Each window is one test sample, presumably fed to the CNN one at a time
    (the original comment says each window is tested as its own batch).

    Args:
        file_in: Path handed to ``read_signal_txt``.
        size: Side length of each square window over the first two axes;
            must be >= 1.
        rows, cols, bands: Dimensions handed to ``read_signal_txt``. The
            defaults reproduce the original hard-coded 2100, 510, 224.

    Returns:
        A list of windows ``data[i - size:i, j - size:j, :]``, one per valid
        end corner ``(i, j)``. Windows end at ``(i, j)`` and are not centred
        on it. The first axis varies fastest (``j`` is the outer loop), and
        each window keeps the full last axis. The list is empty if ``size``
        exceeds either of the first two dimensions.

    Raises:
        ValueError: If ``size`` < 1. A zero or negative size previously
            produced empty or meaningless slices.
    """
    if size < 1:
        raise ValueError("size must be >= 1, got %r" % (size,))
    print_info("start read test data...")
    data = read_signal_txt(file_in, rows, cols, bands)
    print_info("end read test data...")
    # Unpacking three values keeps the original requirement of a 3-D cube.
    x, y, _ = data.shape
    return [data[i - size:i, j - size:j, :]
            for j in range(size, y + 1)
            for i in range(size, x + 1)]
示例#4
0
def read_all_train_data(path_pattern='E:/Imaging/CNNROI/5_%d_%d.txt',
                        num_classes=5, samples_per_class=20,
                        patch_size=5, bands=224):
    """Load every training patch and build matching one-hot labels.

    For each class index ``i`` in ``range(num_classes)`` and sample index
    ``j`` in ``1..samples_per_class``, the file ``path_pattern % (i, j)`` is
    read. The class index taken from the file name becomes the label.

    Args:
        path_pattern: ``%``-format string with two ``%d`` slots, filled with
            the class index and the sample index. The default is the original
            hard-coded path.
        num_classes: Number of classes (and label width). Defaults to 5.
        samples_per_class: Files per class, numbered from 1. Defaults to 20.
        patch_size: Both spatial dimensions handed to ``read_signal_txt``.
            Defaults to 5.
        bands: Third dimension handed to ``read_signal_txt``. Defaults to 224.

    Returns:
        ``(x_data, y_data)`` as float32 NumPy arrays. Rows are grouped by
        class in ascending order, and ``y_data`` row ``k`` is the one-hot
        label of ``x_data`` row ``k``.
    """
    x_data = []
    y_data = []
    for i in range(num_classes):
        # One-hot label for class i; replaces a fixed if-chain that only
        # handled exactly five classes.
        label = [0] * num_classes
        label[i] = 1
        for j in range(1, samples_per_class + 1):
            cur_data = read_signal_txt(path_pattern % (i, j),
                                       patch_size, patch_size, bands)
            y_data.append(list(label))
            x_data.append(cur_data)
    print_info("read done!")
    x_data = np.array(x_data, dtype=np.float32)
    y_data = np.array(y_data, dtype=np.float32)
    return x_data, y_data
示例#5
0
def read_all_train_data(path_pattern='F:/Python/workshop/fishc/CNNROI/5_%d_%d.txt',
                        num_classes=5, samples_per_class=20,
                        patch_size=5, bands=224):
    """Load every training patch and build matching one-hot labels.

    For each class index ``i`` in ``range(num_classes)`` and sample index
    ``j`` in ``1..samples_per_class``, the file ``path_pattern % (i, j)`` is
    read. The class number in the file name is used as the class label, i.e.
    the Y value during training.

    Args:
        path_pattern: ``%``-format string with two ``%d`` slots, filled with
            the class index and the sample index. The default is the original
            hard-coded path.
        num_classes: Number of classes (and label width). Defaults to 5.
        samples_per_class: Files per class, numbered from 1. Defaults to 20.
        patch_size: Both spatial dimensions handed to ``read_signal_txt``.
            Defaults to 5.
        bands: Third dimension handed to ``read_signal_txt``. Defaults to 224.

    Returns:
        ``(x_data, y_data)`` as float32 NumPy arrays. Rows are grouped by
        class in ascending order, and ``y_data`` row ``k`` is the one-hot
        label of ``x_data`` row ``k``.
    """
    x_data = []
    y_data = []
    for i in range(num_classes):
        # One-hot label for class i; replaces a fixed if-chain that only
        # handled exactly five classes.
        label = [0] * num_classes
        label[i] = 1
        for j in range(1, samples_per_class + 1):
            cur_data = read_signal_txt(path_pattern % (i, j),
                                       patch_size, patch_size, bands)
            y_data.append(list(label))
            x_data.append(cur_data)
    print_info("read done!")
    x_data = np.array(x_data, dtype=np.float32)
    y_data = np.array(y_data, dtype=np.float32)
    return x_data, y_data