def train(model_dir, sensitivities_file, eval_metric_loss):
    """Train a PaddleX YOLOv3 detector on the VOC-format data in ``dataset/``.

    Args:
        model_dir: ``None`` to start from the "IMAGENET" pretrained weights;
            otherwise a directory that is passed to ``model.train`` as
            ``pretrain_weights``.
        sensitivities_file: ``None`` for a regular run. Otherwise either the
            literal ``'DEFAULT'`` or the path of an existing file; the value
            is forwarded to ``model.train`` and the output goes to
            ``output/yolov3_mobile_prune`` instead of ``output/yolov3_mobile``.
            NOTE(review): ``'DEFAULT'`` presumably selects sensitivities
            shipped with PaddleX -- confirm against the PaddleX docs.
        eval_metric_loss: Forwarded unchanged to ``model.train``.

    Raises:
        AssertionError: If ``model_dir`` is given but is not a directory, or
            if ``sensitivities_file`` is a path that does not exist.
    """
    # Validate the arguments before any dataset is built so a bad path fails
    # immediately. The errors are raised explicitly rather than through
    # `assert` so the checks still run under `python -O`; AssertionError and
    # the original messages are kept so existing callers see no difference.
    if model_dir is not None and not os.path.isdir(model_dir):
        raise AssertionError("Path {} is not a directory".format(model_dir))
    if (sensitivities_file is not None and sensitivities_file != 'DEFAULT'
            and not os.path.exists(sensitivities_file)):
        raise AssertionError("Path {} not exist".format(sensitivities_file))

    # Use the weights pretrained on ImageNet unless a model directory is given.
    pretrain_weights = "IMAGENET" if model_dir is None else model_dir

    # Runs that receive a sensitivities file are saved separately.
    if sensitivities_file is None:
        save_dir = "output/yolov3_mobile"
    else:
        save_dir = "output/yolov3_mobile_prune"

    # Define the transforms used for training and for evaluation.
    train_transforms = transforms.Compose([
        transforms.MixupImage(mixup_epoch=250),
        transforms.RandomDistort(),
        transforms.RandomExpand(),
        transforms.RandomCrop(),
        transforms.Resize(target_size=608, interp='RANDOM'),
        transforms.RandomHorizontalFlip(),
        transforms.Normalize()
    ])
    eval_transforms = transforms.Compose([
        transforms.Resize(target_size=608, interp='CUBIC'),
        transforms.Normalize()
    ])

    # Define the datasets used for training and for evaluation.
    train_dataset = pdx.datasets.VOCDetection(
        data_dir='dataset',
        file_list='dataset/train_list.txt',
        label_list='dataset/labels.txt',
        transforms=train_transforms,
        shuffle=True)
    eval_dataset = pdx.datasets.VOCDetection(
        data_dir='dataset',
        file_list='dataset/val_list.txt',
        label_list='dataset/labels.txt',
        transforms=eval_transforms)

    # The number of classes comes from the labels of the training dataset.
    num_classes = len(train_dataset.labels)
    model = pdx.det.YOLOv3(num_classes=num_classes)
    model.train(num_epochs=400,
                train_dataset=train_dataset,
                train_batch_size=10,
                eval_dataset=eval_dataset,
                learning_rate=0.0001,
                lr_decay_epochs=[310, 350],
                pretrain_weights=pretrain_weights,
                save_dir=save_dir,
                use_vdl=True,
                sensitivities_file=sensitivities_file,
                eval_metric_loss=eval_metric_loss)
# Example #2
# 0
# Dataset root: used as `data_dir` and as the prefix of the file-list paths.
base = "./data/"
import os

import paddlex as pdx
from paddlex.det import transforms

# Transform pipeline applied to training samples; images are resized to 512
# with interp='RANDOM' and normalized as the final step.
train_transforms = transforms.Compose([
    transforms.MixupImage(mixup_epoch=250),
    transforms.RandomDistort(),
    transforms.RandomExpand(),
    transforms.RandomCrop(),
    transforms.Resize(target_size=512, interp='RANDOM'),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize(),
])

# Evaluation samples are only resized (interp='CUBIC') and normalized.
eval_transforms = transforms.Compose([
    transforms.Resize(target_size=512, interp='CUBIC'),
    transforms.Normalize(),
])

# Both splits share one label list. It is derived from `base`, like the file
# lists, so changing `base` relocates every dataset path together.
label_list_path = os.path.join(base, 'labels.txt')

train_dataset = pdx.datasets.VOCDetection(
    data_dir=base,
    file_list=os.path.join(base, 'train.txt'),
    label_list=label_list_path,
    transforms=train_transforms,
    shuffle=True)
eval_dataset = pdx.datasets.VOCDetection(
    data_dir=base,
    file_list=os.path.join(base, 'valid.txt'),
    transforms=eval_transforms,
    label_list=label_list_path)