Пример #1
0
def train(model_dir, sensitivities_file, eval_metric_loss):
    """Train a UNet on the optic-disc segmentation dataset.

    The dataset archive is downloaded and extracted into the current
    directory, train/eval transform pipelines are built, and a
    ``pdx.seg.UNet`` is trained for 20 epochs (batch size 4, lr 0.01).

    Args:
        model_dir: Directory holding weights to start from. When ``None``,
            ``"COCO"`` is passed as ``pretrain_weights`` instead.
        sensitivities_file: ``None``, the string ``'DEFAULT'`` (which skips
            the existence check), or a path to an existing file. Forwarded
            to ``model.train``; any non-``None`` value also switches the
            output directory to ``output/unet_prune``.
        eval_metric_loss: Forwarded unchanged to ``model.train``.

    Raises:
        NotADirectoryError: ``model_dir`` is given but is not a directory.
        FileNotFoundError: ``sensitivities_file`` is a path (not
            ``'DEFAULT'``) that does not exist.
    """
    # Validate arguments before the (slow) dataset download. Real exceptions
    # are raised instead of ``assert`` because ``python -O`` strips asserts.
    if model_dir is None:
        # Start from the weights pretrained on the COCO dataset.
        pretrain_weights = "COCO"
    else:
        if not os.path.isdir(model_dir):
            raise NotADirectoryError(
                "Path {} is not a directory".format(model_dir))
        pretrain_weights = model_dir

    save_dir = "output/unet"
    if sensitivities_file is not None:
        if sensitivities_file != 'DEFAULT' and not os.path.exists(
                sensitivities_file):
            raise FileNotFoundError(
                "Path {} not exist".format(sensitivities_file))
        save_dir = "output/unet_prune"

    # Download and extract the optic-disc segmentation dataset.
    optic_dataset = 'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz'
    pdx.utils.download_and_decompress(optic_dataset, path='./')

    # Transforms used for training (with augmentation) and evaluation.
    train_transforms = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ResizeRangeScaling(),
        transforms.RandomPaddingCrop(crop_size=512),
        transforms.Normalize()
    ])
    eval_transforms = transforms.Compose([
        transforms.ResizeByLong(long_size=512),
        transforms.Padding(target_size=512),
        transforms.Normalize()
    ])

    # Training and evaluation datasets.
    train_dataset = pdx.datasets.SegDataset(
        data_dir='optic_disc_seg',
        file_list='optic_disc_seg/train_list.txt',
        label_list='optic_disc_seg/labels.txt',
        transforms=train_transforms,
        shuffle=True)
    eval_dataset = pdx.datasets.SegDataset(
        data_dir='optic_disc_seg',
        file_list='optic_disc_seg/val_list.txt',
        label_list='optic_disc_seg/labels.txt',
        transforms=eval_transforms)

    num_classes = len(train_dataset.labels)
    model = pdx.seg.UNet(num_classes=num_classes)
    model.train(
        num_epochs=20,
        train_dataset=train_dataset,
        train_batch_size=4,
        eval_dataset=eval_dataset,
        learning_rate=0.01,
        pretrain_weights=pretrain_weights,
        save_dir=save_dir,
        use_vdl=True,
        sensitivities_file=sensitivities_file,
        eval_metric_loss=eval_metric_loss)
Пример #2
0
 def __init__(self):
     """Load the meter detection and segmentation inference models.

     Both models are loaded with ``pdx.load_model`` from sub-directories of
     ``self.directory`` (presumably provided by the base class or a class
     attribute -- confirm). Segmenter preprocessing is normalization only.
     """
     super(BarometerReader, self).__init__()
     model_root = self.directory
     det_model_path = os.path.join(model_root, 'meter_det_inference_model')
     self.detector = pdx.load_model(det_model_path)
     seg_model_path = os.path.join(model_root, 'meter_seg_inference_model')
     self.segmenter = pdx.load_model(seg_model_path)
     self.seg_transform = T.Compose([T.Normalize()])
Пример #3
0
 def __init__(self, detector_dir, segmenter_dir):
     """Load the meter detector, the meter segmenter and an OCR helper.

     Args:
         detector_dir: Path passed to ``pdx.load_model`` for the detector.
         segmenter_dir: Path passed to ``pdx.load_model`` for the segmenter.

     Raises:
         FileNotFoundError: if either path does not exist. It subclasses
             ``Exception``, so existing ``except Exception`` handlers still
             catch it.
     """
     # Check both paths before loading anything, so a bad segmenter path is
     # reported without first paying for the detector load.
     for model_dir in (detector_dir, segmenter_dir):
         if not osp.exists(model_dir):
             raise FileNotFoundError(
                 "Model path {} does not exist".format(model_dir))
     self.detector = pdx.load_model(detector_dir)
     self.segmenter = pdx.load_model(segmenter_dir)
     # Images are resized to (METER_SHAPE, METER_SHAPE) before being fed to
     # the segmenter, so this transform only needs to normalize them.
     self.seg_transforms = transforms.Compose([transforms.Normalize()])
     self.ocr = OcrInference()
Пример #4
0
import paddlex as pdx
from paddlex.seg import transforms

# Define the transforms used for training and evaluation.
# Training: random crop (crop_size=769) followed by random flip, blur,
# rotation and colour-distortion augmentations, then normalization.
train_transforms = transforms.Compose([
    transforms.RandomPaddingCrop(crop_size=769),
    transforms.RandomHorizontalFlip(prob=0.5),
    transforms.RandomVerticalFlip(prob=0.5),
    transforms.RandomBlur(prob=0.5),
    transforms.RandomRotate(rotate_range=35),
    transforms.RandomDistort(brightness_prob=0.5,
                             contrast_prob=0.5,
                             saturation_prob=0.5,
                             hue_prob=0.5),
    transforms.Normalize()
])

# Evaluation: no augmentation; pad to target_size=769 (presumably to match
# the training crop size -- confirm) and normalize.
eval_transforms = transforms.Compose(
    [transforms.Padding(target_size=769),
     transforms.Normalize()])

# Define the datasets. All paths are relative to the working directory.
train_dataset = pdx.datasets.SegDataset(data_dir='dataset',
                                        file_list='dataset/train_list.txt',
                                        label_list='dataset/labels.txt',
                                        transforms=train_transforms,
                                        shuffle=True)
eval_dataset = pdx.datasets.SegDataset(data_dir='dataset',
                                       file_list='dataset/val_list.txt',
                                       label_list='dataset/labels.txt',
Пример #5
0
import os
import paddlex as pdx

# Expose only GPU 0 to the process.
# NOTE(review): `paddlex` is already imported above before this variable is
# set -- confirm the setting still takes effect, or move it before the import.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from paddlex.seg import transforms
# NOTE(review): `iaa` is not used in the code visible here; verify it is
# needed further down before keeping the import.
import imgaug.augmenters as iaa

# Training transforms: resize, random 256 crop, light blur/rotation and a
# horizontal flip, then normalization.
train_transforms = transforms.Compose([
    transforms.Resize(target_size=300),
    transforms.RandomPaddingCrop(crop_size=256),
    transforms.RandomBlur(prob=0.1),
    transforms.RandomRotate(rotate_range=15),
    # Disabled augmentation, kept for reference:
    # transforms.RandomDistort(brightness_range=0.5),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize()
])
# Evaluation transforms: resize to 256 and normalize (no augmentation).
eval_transforms = transforms.Compose(
    [transforms.Resize(256), transforms.Normalize()])

# Notebook shell commands used to unpack the raw data; kept for reference.
# !unzip data/data55723/img_testA.zip
# !unzip data/data55723/train_data.zip

# !unzip train_data/lab_train.zip
# !unzip train_data/img_train.zip

import numpy as np

# Presumably `datas` is filled by later code (not shown) with entries built
# from the image directory `image_base` and label directory `annos_base`.
datas = []
image_base = 'img_train'
annos_base = 'lab_train'
Пример #6
0
# Use GPU card 0 only. This is set before paddlex is imported below;
# presumably that ordering is intentional so the framework only sees this
# device -- keep it. (Assumes `os` is imported earlier in the script.)
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import paddlex as pdx
from paddlex.seg import transforms

# Download and extract the optic-disc segmentation dataset into ./
optic_dataset = 'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz'
pdx.utils.download_and_decompress(optic_dataset, path='./')

# Define the transforms used for training and evaluation.
# Training: random horizontal flip, resize to 512, random 500 crop, normalize.
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize(target_size=512),
    transforms.RandomPaddingCrop(crop_size=500),
    transforms.Normalize()
])

# Evaluation: resize to 512 and normalize (no augmentation).
eval_transforms = transforms.Compose(
    [transforms.Resize(512), transforms.Normalize()])

# Define the datasets used for training and evaluation.
train_dataset = pdx.datasets.SegDataset(
    data_dir='optic_disc_seg',
    file_list='optic_disc_seg/train_list.txt',
    label_list='optic_disc_seg/labels.txt',
    transforms=train_transforms,
    shuffle=True)
eval_dataset = pdx.datasets.SegDataset(data_dir='optic_disc_seg',
                                       file_list='optic_disc_seg/val_list.txt',
                                       label_list='optic_disc_seg/labels.txt',