import os

import paddlex as pdx
from paddlex.det import transforms


def train(model_dir, sensitivities_file, eval_metric_loss):
    # Define the transforms used for training and evaluation
    train_transforms = transforms.Compose([
        transforms.MixupImage(mixup_epoch=250),
        transforms.RandomDistort(),
        transforms.RandomExpand(),
        transforms.RandomCrop(),
        transforms.Resize(target_size=608, interp='RANDOM'),
        transforms.RandomHorizontalFlip(),
        transforms.Normalize()
    ])
    eval_transforms = transforms.Compose([
        transforms.Resize(target_size=608, interp='CUBIC'),
        transforms.Normalize()
    ])

    # Define the datasets used for training and evaluation
    train_dataset = pdx.datasets.VOCDetection(
        data_dir='dataset',
        file_list='dataset/train_list.txt',
        label_list='dataset/labels.txt',
        transforms=train_transforms,
        shuffle=True)
    eval_dataset = pdx.datasets.VOCDetection(data_dir='dataset',
                                             file_list='dataset/val_list.txt',
                                             label_list='dataset/labels.txt',
                                             transforms=eval_transforms)

    if model_dir is None:
        # Use weights pretrained on the ImageNet dataset
        pretrain_weights = "IMAGENET"
    else:
        assert os.path.isdir(model_dir), "Path {} is not a directory".format(
            model_dir)
        pretrain_weights = model_dir
    save_dir = "output/yolov3_mobile"
    if sensitivities_file is not None:
        if sensitivities_file != 'DEFAULT':
            assert os.path.exists(
                sensitivities_file), "Path {} does not exist".format(
                    sensitivities_file)
        save_dir = "output/yolov3_mobile_prune"

    num_classes = len(train_dataset.labels)
    model = pdx.det.YOLOv3(num_classes=num_classes)
    model.train(num_epochs=400,
                train_dataset=train_dataset,
                train_batch_size=10,
                eval_dataset=eval_dataset,
                learning_rate=0.0001,
                lr_decay_epochs=[310, 350],
                pretrain_weights=pretrain_weights,
                save_dir=save_dir,
                use_vdl=True,
                sensitivities_file=sensitivities_file,
                eval_metric_loss=eval_metric_loss)
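
A minimal sketch of how this train() function might be invoked from the command line; the argparse flags below are hypothetical and simply mirror the function's parameters, they are not part of the original example:

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='Train a PaddleX YOLOv3 model, optionally with pruning sensitivities')
    # --model_dir: directory with previously trained weights; IMAGENET pretrain is used if omitted
    parser.add_argument('--model_dir', default=None)
    # --sensitivities_file: path to a sensitivities file, or the string 'DEFAULT'
    parser.add_argument('--sensitivities_file', default=None)
    # --eval_metric_loss: acceptable metric loss when pruning (0.05 is an illustrative default)
    parser.add_argument('--eval_metric_loss', type=float, default=0.05)
    args = parser.parse_args()
    train(args.model_dir, args.sensitivities_file, args.eval_metric_loss)
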
Example #2
import os

import paddlex as pdx
from paddlex.det import transforms

base = './data/'

train_transforms = transforms.Compose([
    transforms.MixupImage(mixup_epoch=250),
    transforms.RandomDistort(),
    transforms.RandomExpand(),
    transforms.RandomCrop(),
    transforms.Resize(target_size=512, interp='RANDOM'),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize(),
])

eval_transforms = transforms.Compose([
    transforms.Resize(target_size=512, interp='CUBIC'),
    transforms.Normalize(),
])
train_dataset = pdx.datasets.VOCDetection(data_dir=base,
                                          file_list=os.path.join(
                                              base, 'train.txt'),
                                          label_list='./data/labels.txt',
                                          transforms=train_transforms,
                                          shuffle=True)
eval_dataset = pdx.datasets.VOCDetection(data_dir=base,
                                         file_list=os.path.join(
                                             base, 'valid.txt'),
                                         transforms=eval_transforms,
                                         label_list='./data/labels.txt')
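
The example stops after the dataset definitions; a minimal sketch of how training could continue, reusing the model and model.train() pattern from the first example (the hyperparameters below are illustrative assumptions, not taken from the source):

# Sketch only: epochs, batch size and learning rate are illustrative values.
num_classes = len(train_dataset.labels)
model = pdx.det.YOLOv3(num_classes=num_classes)
model.train(num_epochs=270,
            train_dataset=train_dataset,
            train_batch_size=8,
            eval_dataset=eval_dataset,
            learning_rate=0.000125,
            lr_decay_epochs=[210, 240],
            save_dir='output/yolov3',
            use_vdl=True)
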
Example #3
import paddlex as pdx
# Assuming `t` aliases paddlex.det.transforms, as in the examples above
from paddlex.det import transforms as t

# eval_transforms = transforms.Compose([
#     transforms.Resize([1920, 1080]), transforms.Normalize()
# ])

# train_transforms = t.Compose([t.ComposedYOLOv3Transforms("train")])
# eval_transforms = t.Compose([t.ComposedYOLOv3Transforms("eval")])

width = 255
height = 255
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
epoch_num = 100

train_transforms = t.Compose([
    t.RandomHorizontalFlip(),
    t.RandomExpand(),
    t.RandomDistort(),
    # t.MixupImage(mixup_epoch=int(epoch_num * 0.5)),
    t.Resize(target_size=width, interp='RANDOM'),
    t.Normalize(mean=mean, std=std),
])
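
The matching eval_transforms only appear as a commented-out block above; a sketch of what they might look like, following the pattern of the earlier examples (the interpolation and normalization choices are assumptions):

# Sketch only: mirrors the resize/normalize settings of train_transforms above.
eval_transforms = t.Compose([
    t.Resize(target_size=width, interp='CUBIC'),
    t.Normalize(mean=mean, std=std),
])
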
# Define the datasets used for training and evaluation
# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/datasets.html#paddlex-datasets-vocdetection
train_dataset = pdx.datasets.CocoDetection(
    data_dir='/home/aistudio/data/data67498/DatasetId_153862_1611403574/Images',
    ann_file=
    '/home/aistudio/data/data67498/DatasetId_153862_1611403574/Annotations/coco_info.json',
    transforms=train_transforms,
    num_workers=8,
    buffer_size=256,
    parallel_method='process',