Example #1
import math

def generate_cross_validate_dataset():
    # cross_validate_tfrecord, test_tfrecord, BATCH_SIZE and k_fold are
    # module-level constants defined elsewhere in the project.
    cross_validate_dataset = get_parsed_dataset(
        tfrecord_name=cross_validate_tfrecord)
    test_dataset = get_parsed_dataset(tfrecord_name=test_tfrecord)

    cross_validate_dataset_count = get_the_length_of_dataset(
        cross_validate_dataset)
    test_count = get_the_length_of_dataset(test_dataset)

    cross_validate_dataset = cross_validate_dataset.batch(
        batch_size=BATCH_SIZE)
    test_dataset = test_dataset.batch(batch_size=BATCH_SIZE)
    # Split the batch stream into k folds of `boundary` batches each.
    batch_num = math.ceil(cross_validate_dataset_count / BATCH_SIZE)
    boundary = math.ceil(batch_num / k_fold)
    cross_validate_dataset_list = []
    for i in range(k_fold):
        cross_validate_dataset_item = []
        # Keep only the batches whose index falls inside the i-th fold.
        for index, cross_validate_item in enumerate(cross_validate_dataset):
            if i * boundary <= index < (i + 1) * boundary:
                cross_validate_dataset_item.append(cross_validate_item)
        # print('number of batches: {}'.format(len(cross_validate_dataset_item)))
        cross_validate_dataset_list.append(cross_validate_dataset_item)
    # print('number of folds: {}, batches per fold: {}, type of each batch: {}'
    #       .format(len(cross_validate_dataset_list), len(cross_validate_dataset_list[0]),
    #               type(cross_validate_dataset_list[0][0])))

    return cross_validate_dataset_list, test_dataset, cross_validate_dataset_count, test_count
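
The fold-splitting arithmetic above is easiest to verify in isolation. The snippet below reproduces it with plain integers standing in for dataset batches; the BATCH_SIZE, k_fold, and sample-count values are made up for illustration and are not the project's settings.

import math

# Standalone illustration of the fold-splitting arithmetic used above;
# integers stand in for dataset batches and all constants are made up.
BATCH_SIZE = 32
k_fold = 5
sample_count = 1000

batch_num = math.ceil(sample_count / BATCH_SIZE)  # 32 batches in total
boundary = math.ceil(batch_num / k_fold)          # 7 batches per fold

folds = []
for i in range(k_fold):
    folds.append([index for index in range(batch_num)
                  if i * boundary <= index < (i + 1) * boundary])

# The last fold is smaller whenever batch_num is not divisible by k_fold.
print([len(fold) for fold in folds])  # -> [7, 7, 7, 7, 4]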
Example #2
def generate_datasets():
    train_dataset = get_parsed_dataset(tfrecord_name=train_tfrecord)
    valid_dataset = get_parsed_dataset(tfrecord_name=valid_tfrecord)
    test_dataset = get_parsed_dataset(tfrecord_name=test_tfrecord)

    train_count = get_the_length_of_dataset(train_dataset)
    valid_count = get_the_length_of_dataset(valid_dataset)
    test_count = get_the_length_of_dataset(test_dataset)

    return train_dataset, valid_dataset, test_dataset, train_count, valid_count, test_count
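
Callers of this unbatched variant apply batching themselves, which is exactly what Example #3 below folds into the function. A minimal usage sketch, assuming the same module-level BATCH_SIZE constant:

train_dataset, valid_dataset, test_dataset, \
    train_count, valid_count, test_count = generate_datasets()

# Batch at the call site instead of inside generate_datasets().
train_dataset = train_dataset.batch(batch_size=BATCH_SIZE)
valid_dataset = valid_dataset.batch(batch_size=BATCH_SIZE)
test_dataset = test_dataset.batch(batch_size=BATCH_SIZE)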
Example #3
def generate_datasets():
    train_dataset = get_parsed_dataset(tfrecord_name=train_tfrecord)
    valid_dataset = get_parsed_dataset(tfrecord_name=valid_tfrecord)
    test_dataset = get_parsed_dataset(tfrecord_name=test_tfrecord)

    train_count = get_the_length_of_dataset(train_dataset)
    valid_count = get_the_length_of_dataset(valid_dataset)
    test_count = get_the_length_of_dataset(test_dataset)
    # Read the datasets in batches.
    train_dataset = train_dataset.batch(batch_size=BATCH_SIZE)
    valid_dataset = valid_dataset.batch(batch_size=BATCH_SIZE)
    test_dataset = test_dataset.batch(batch_size=BATCH_SIZE)
    return train_dataset, valid_dataset, test_dataset, train_count, valid_count, test_count
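
The only difference from Example #2 is the .batch() call. The self-contained snippet below, which assumes nothing beyond TensorFlow's tf.data API, shows how batching regroups an element stream:

import tensorflow as tf

# Ten scalar elements become four batches of up to three elements each;
# the final batch is smaller because 10 is not divisible by 3.
dataset = tf.data.Dataset.range(10)
for batch in dataset.batch(batch_size=3):
    print(batch.numpy())
# -> [0 1 2], [3 4 5], [6 7 8], [9]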
Example #4
import numpy as np
import tensorflow as tf

def test_image_standard():
    train_dataset = get_parsed_dataset('dataset/train.tfrecord')
    train_dataset_batch = train_dataset.batch(batch_size=4)
    index = 0
    for train_sample in train_dataset_batch:
        image_tensor = []
        # Pull the serialized image bytes and labels out of the parsed batch.
        image_raw_list = train_sample['image_raw'].numpy()
        image_label = train_sample['label'].numpy()
        for image_raw_item in image_raw_list:
            # image_data = tf.io.decode_image(contents=image_raw_item, channels=3, dtype=tf.dtypes.float32)
            image_data = load_and_preprocess_image(image_raw_item,
                                                   data_augmentation=True)
            # plt.figure()
            # plt.imshow(image_data)
            # plt.show()
            image_tensor.append(image_data)
            print('{} : max:{}, min:{}, mean:{}, shape:{}'.format(
                index, np.max(image_data), np.min(image_data),
                np.mean(image_data), image_data.shape))
            # print('shape:{}'.format(image_data.shape))
            index += 1
        # Stack the per-image tensors into a single (N, H, W, C) batch tensor.
        images = tf.stack(image_tensor, axis=0)
        print('{} : {}'.format(image_label, images.shape))
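
load_and_preprocess_image is not defined in any of these examples; the commented-out tf.io.decode_image call above hints at its core. What follows is a hypothetical sketch of one plausible implementation, not the project's actual code: the 224x224 target size and the random-flip augmentation are assumptions.

import tensorflow as tf

# Hypothetical stand-in for the project's load_and_preprocess_image.
# Decodes the serialized bytes to float32 in [0, 1] (decode_image scales
# when given a float dtype), resizes to an assumed 224x224, and applies
# a random horizontal flip when augmentation is requested.
def load_and_preprocess_image(image_raw, data_augmentation=False):
    image = tf.io.decode_image(contents=image_raw, channels=3,
                               dtype=tf.dtypes.float32,
                               expand_animations=False)
    image = tf.image.resize(image, [224, 224])
    if data_augmentation:
        image = tf.image.random_flip_left_right(image)
    return image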