import math

import numpy as np
import tensorflow as tf

# BATCH_SIZE, k_fold, the *_tfrecord paths, get_parsed_dataset and
# get_the_length_of_dataset are assumed to be provided by the project's
# configuration and dataset-parsing modules.


def generate_cross_validate_dataset():
    cross_validate_dataset = get_parsed_dataset(tfrecord_name=cross_validate_tfrecord)
    test_dataset = get_parsed_dataset(tfrecord_name=test_tfrecord)
    cross_validate_dataset_count = get_the_length_of_dataset(cross_validate_dataset)
    test_count = get_the_length_of_dataset(test_dataset)
    cross_validate_dataset = cross_validate_dataset.batch(batch_size=BATCH_SIZE)
    test_dataset = test_dataset.batch(batch_size=BATCH_SIZE)
    batch_num = math.ceil(cross_validate_dataset_count / BATCH_SIZE)
    # number of batches assigned to each fold
    boundary = math.ceil(batch_num / k_fold)
    cross_validate_dataset_list = []
    for i in range(k_fold):
        cross_validate_dataset_item = []
        for index, cross_validate_item in enumerate(cross_validate_dataset):
            # collect the batches that belong to the i-th fold
            if i * boundary <= index < (i + 1) * boundary:
                cross_validate_dataset_item.append(cross_validate_item)
        # print('number of batches: {}'.format(len(cross_validate_dataset_item)))
        cross_validate_dataset_list.append(cross_validate_dataset_item)
    # print('number of folds: {}, batches per fold: {}, type of each batch: {}'
    #       .format(len(cross_validate_dataset_list), len(cross_validate_dataset_list[0]),
    #               type(cross_validate_dataset_list[0][0])))
    return cross_validate_dataset_list, test_dataset, cross_validate_dataset_count, test_count
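# Each fold in cross_validate_dataset_list is a plain Python list of batches,
# so a caller can rotate which fold serves as the validation split. A minimal
# sketch of that rotation (the train/evaluate calls are hypothetical
# placeholders, not part of this project):
def _cross_validate_sketch():
    folds, test_dataset, _, _ = generate_cross_validate_dataset()
    for i in range(len(folds)):
        valid_batches = folds[i]  # fold i serves as the validation split
        train_batches = [batch
                         for j, fold in enumerate(folds) if j != i
                         for batch in fold]  # remaining folds form the training split
        # train_on_batches(train_batches)      # hypothetical
        # evaluate_on_batches(valid_batches)   # hypothetical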
def generate_datasets():
    train_dataset = get_parsed_dataset(tfrecord_name=train_tfrecord)
    valid_dataset = get_parsed_dataset(tfrecord_name=valid_tfrecord)
    test_dataset = get_parsed_dataset(tfrecord_name=test_tfrecord)
    train_count = get_the_length_of_dataset(train_dataset)
    valid_count = get_the_length_of_dataset(valid_dataset)
    test_count = get_the_length_of_dataset(test_dataset)
    # read the datasets in the form of batches
    train_dataset = train_dataset.batch(batch_size=BATCH_SIZE)
    valid_dataset = valid_dataset.batch(batch_size=BATCH_SIZE)
    test_dataset = test_dataset.batch(batch_size=BATCH_SIZE)
    return train_dataset, valid_dataset, test_dataset, train_count, valid_count, test_count
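# get_the_length_of_dataset is not shown in this section. A minimal sketch of
# what such a helper might look like (an assumption, not the project's actual
# implementation): it iterates the un-batched tf.data.Dataset once and counts
# the elements, which is why it is called before .batch() above.
def _count_dataset_items(dataset):
    count = 0
    for _ in dataset:
        count += 1
    return count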
def test_image_standard():
    train_dataset = get_parsed_dataset('dataset/train.tfrecord')
    train_dataset_batch = train_dataset.batch(batch_size=4)
    index = 0
    for train_sample in train_dataset_batch:
        image_tensor = []
        image_raw_list = train_sample['image_raw'].numpy()
        image_label = train_sample['label'].numpy()
        for image_raw_item in image_raw_list:
            # image_data = tf.io.decode_image(contents=image_raw_item, channels=3, dtype=tf.dtypes.float32)
            image_data = load_and_preprocess_image(image_raw_item, data_augmentation=True)
            # plt.figure()
            # plt.imshow(image_data)
            # plt.show()
            image_tensor.append(image_data)
            print('{} : max:{}, min:{}, mean:{}, shape:{}'.format(
                index, np.max(image_data), np.min(image_data), np.mean(image_data), image_data.shape))
            index += 1
        images = tf.stack(image_tensor, axis=0)
        print('{} : {}'.format(image_label, images.shape))
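# load_and_preprocess_image is referenced above but defined elsewhere. Based on
# the commented-out decode line in test_image_standard, a minimal sketch of the
# assumed behaviour (decode raw bytes to float32 in [0, 1], optionally apply a
# random flip as augmentation); this is an illustration, not the project's
# actual implementation:
def _load_and_preprocess_image_sketch(image_raw, data_augmentation=False):
    # decode_image with dtype=float32 scales pixel values into [0, 1]
    image = tf.io.decode_image(contents=image_raw, channels=3, dtype=tf.dtypes.float32)
    if data_augmentation:
        image = tf.image.random_flip_left_right(image)
    return image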