コード例 #1
0
ファイル: mobileNetV3_ssld.py プロジェクト: IanVzs/demo_test
def train():
    """Fine-tune a MobileNetV3-large SSLD classifier on mini_imagenet_veg.

    Uses module-level names defined elsewhere in the file: ``pdx`` (the
    paddlex package), ``data_dir`` (directory containing
    ``mini_imagenet_veg``) and ``output_dir`` (root for saved models).
    Trains for 12 epochs with batch size 32, evaluates on ``val_list.txt``,
    saves every epoch under ``<output_dir>/output/mobilenetv3_large_ssld``
    and enables VisualDL logging.
    """
    from paddlex.cls import transforms
    # Training augmentation: random crop to 224, horizontal flip, normalize.
    train_transforms = transforms.Compose([
        transforms.RandomCrop(crop_size=224),
        transforms.RandomHorizontalFlip(),
        transforms.Normalize()
    ])
    # Evaluation preprocessing: resize the short side to 256, center-crop 224.
    eval_transforms = transforms.Compose([
        transforms.ResizeByShort(short_size=256),
        transforms.CenterCrop(crop_size=224),
        transforms.Normalize()
    ])

    # shuffle=True: PaddleX's ImageNet dataset defaults to shuffle=False, so
    # without it the training samples would be fed in file-list order.
    train_dataset = pdx.datasets.ImageNet(
        data_dir=f'{data_dir}/mini_imagenet_veg',
        file_list=f'{data_dir}/mini_imagenet_veg/train_list.txt',
        label_list=f'{data_dir}/mini_imagenet_veg/labels.txt',
        transforms=train_transforms,
        shuffle=True)
    # Evaluation data keeps the default (unshuffled) order.
    eval_dataset = pdx.datasets.ImageNet(
        data_dir=f'{data_dir}/mini_imagenet_veg',
        file_list=f'{data_dir}/mini_imagenet_veg/val_list.txt',
        label_list=f'{data_dir}/mini_imagenet_veg/labels.txt',
        transforms=eval_transforms)

    num_classes = len(train_dataset.labels)
    model = pdx.cls.MobileNetV3_large_ssld(num_classes=num_classes)
    model.train(num_epochs=12,
                train_dataset=train_dataset,
                train_batch_size=32,
                eval_dataset=eval_dataset,
                lr_decay_epochs=[6, 8],
                save_interval_epochs=1,
                learning_rate=0.00625,
                save_dir=f'{output_dir}/output/mobilenetv3_large_ssld',
                use_vdl=True)
コード例 #2
0
ファイル: mobileNetV2.py プロジェクト: IanVzs/demo_test
def train():
    """Fine-tune a MobileNetV2 classifier on the vegetables_cls dataset.

    Relies on module-level names defined elsewhere in the file: ``pdx``
    (the paddlex package), ``data_dir`` (directory containing
    ``vegetables_cls``) and ``output_dir`` (root for saved models).
    Runs 10 epochs with batch size 32 and step decay at epochs 4, 6 and 8,
    saving every epoch under ``<output_dir>/mobilenetv2`` with VisualDL on.
    """
    from paddlex.cls import transforms as T

    dataset_root = f'{data_dir}/vegetables_cls'
    labels_path = f'{dataset_root}/labels.txt'

    # Training-time augmentation: random crop to 224 plus horizontal flip.
    augment = T.Compose([
        T.RandomCrop(crop_size=224),
        T.RandomHorizontalFlip(),
        T.Normalize(),
    ])
    # Evaluation-time preprocessing: short side to 256, then center-crop 224.
    preprocess = T.Compose([
        T.ResizeByShort(short_size=256),
        T.CenterCrop(crop_size=224),
        T.Normalize(),
    ])

    train_set = pdx.datasets.ImageNet(
        data_dir=dataset_root,
        file_list=f'{dataset_root}/train_list.txt',
        label_list=labels_path,
        transforms=augment,
        shuffle=True,
    )
    val_set = pdx.datasets.ImageNet(
        data_dir=dataset_root,
        file_list=f'{dataset_root}/val_list.txt',
        label_list=labels_path,
        transforms=preprocess,
    )

    classifier = pdx.cls.MobileNetV2(num_classes=len(train_set.labels))
    classifier.train(
        num_epochs=10,
        train_dataset=train_set,
        train_batch_size=32,
        eval_dataset=val_set,
        lr_decay_epochs=[4, 6, 8],
        save_interval_epochs=1,
        learning_rate=0.025,
        save_dir=f'{output_dir}/mobilenetv2',
        use_vdl=True,
    )
コード例 #3
0
ファイル: mobilenetv2.py プロジェクト: yzl19940819/test
def train(model_dir=None, sensitivities_file=None, eval_metric_loss=0.05):
    """Train (and optionally prune) a MobileNetV2 classifier on vegetables_cls.

    Downloads and extracts the vegetables classification dataset into the
    current directory, builds the train/eval datasets and trains a
    MobileNetV2 model with PaddleX.

    Args:
        model_dir: Directory of an existing model to use as the pretrained
            weights. If None, ImageNet-pretrained weights ("IMAGENET") are
            used.
        sensitivities_file: Parameter-sensitivity file used as the pruning
            basis, or "DEFAULT" to use the model's preset sensitivity
            information. If not None, results are saved to
            ``./output/mobilenetv2_prune`` instead of ``./output/mobilenetv2``.
        eval_metric_loss: Forwarded to ``model.train``; presumably the
            tolerated eval-metric loss when pruning — confirm in PaddleX docs.

    Raises:
        NotADirectoryError: If ``model_dir`` is given but is not a directory.
        FileNotFoundError: If ``sensitivities_file`` is a path (not
            "DEFAULT") that does not exist.
    """
    # Validate arguments before the (slow) dataset download. Explicit raises
    # are used instead of `assert`, which is stripped under `python -O`.
    if model_dir is not None and not os.path.isdir(model_dir):
        raise NotADirectoryError(
            "Path {} is not a directory".format(model_dir))
    if (sensitivities_file is not None
            and sensitivities_file != "DEFAULT"
            and not os.path.exists(sensitivities_file)):
        raise FileNotFoundError(
            "Path {} not exist".format(sensitivities_file))

    # Download and extract the vegetable classification dataset.
    veg_dataset = 'https://bj.bcebos.com/paddlex/datasets/vegetables_cls.tar.gz'
    pdx.utils.download_and_decompress(veg_dataset, path='./')

    # Transforms for training and evaluation.
    train_transforms = transforms.Compose([
        transforms.RandomCrop(crop_size=224),
        transforms.RandomHorizontalFlip(),
        transforms.Normalize()
    ])
    eval_transforms = transforms.Compose([
        transforms.ResizeByShort(short_size=256),
        transforms.CenterCrop(crop_size=224),
        transforms.Normalize()
    ])

    # Datasets for training and evaluation.
    train_dataset = pdx.datasets.ImageNet(
        data_dir='vegetables_cls',
        file_list='vegetables_cls/train_list.txt',
        label_list='vegetables_cls/labels.txt',
        transforms=train_transforms,
        shuffle=True)
    eval_dataset = pdx.datasets.ImageNet(
        data_dir='vegetables_cls',
        file_list='vegetables_cls/val_list.txt',
        label_list='vegetables_cls/labels.txt',
        transforms=eval_transforms)

    num_classes = len(train_dataset.labels)
    model = pdx.cls.MobileNetV2(num_classes=num_classes)

    # "IMAGENET" selects ImageNet-pretrained weights; otherwise start from the
    # given (already validated) model directory.
    pretrain_weights = "IMAGENET" if model_dir is None else model_dir

    # Pruning runs ("DEFAULT" or a sensitivities path) save to a separate dir.
    if sensitivities_file is None:
        save_dir = './output/mobilenetv2'
    else:
        save_dir = './output/mobilenetv2_prune'

    model.train(
        num_epochs=10,
        train_dataset=train_dataset,
        train_batch_size=32,
        eval_dataset=eval_dataset,
        lr_decay_epochs=[4, 6, 8],
        learning_rate=0.025,
        pretrain_weights=pretrain_weights,
        save_dir=save_dir,
        use_vdl=True,
        sensitivities_file=sensitivities_file,
        eval_metric_loss=eval_metric_loss)
コード例 #4
0
ファイル: gls_clas.py プロジェクト: randornot/plane
# Restrict this process to GPU 0. This is set before `import paddlex` below,
# presumably so it takes effect before Paddle initializes CUDA — keep this order.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from paddlex.cls import transforms
import paddlex as pdx

# Download and extract the vegetable classification dataset
# (commented out: the dataset below is read from a local path instead).
# veg_dataset = 'https://bj.bcebos.com/paddlex/datasets/vegetables_cls.tar.gz'
# pdx.utils.download_and_decompress(veg_dataset, path='./')

# Define the transforms used for training (and evaluation, if re-enabled).
# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/transforms/cls_transforms.html
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotate(),
    transforms.RandomDistort(),
    transforms.Normalize()
])
# Evaluation transforms are currently disabled.
# eval_transforms = transforms.Compose([
#     transforms.Normalize()
# ])

# Define the datasets used for training and evaluation.
# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/datasets.html#paddlex-datasets-imagenet
# NOTE(review): no resize/crop transform is applied, so this presumably
# assumes every image in data67498 already has the same size — confirm.
train_dataset = pdx.datasets.ImageNet(
    data_dir='/home/aistudio/data/data67498/train',
    file_list='/home/aistudio/data/data67498/train/train_list.txt',
    label_list='/home/aistudio/data/data67498/train/labels.txt',
    transforms=train_transforms,
    shuffle=True)
# eval_dataset = pdx.datasets.ImageNet(
#     data_dir='vegetables_cls',