Code Example #1
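Sizes a ShuffleNetV2 JDE model from a checkpoint name and the training set: the identity-head width comes from the dataset's maximum track ID, and the model width ('0.5x' through '2.0x') is inferred from the --src-model filename, apparently as a first step in converting --src-model into --dst-model.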
import argparse
from collections import OrderedDict

import numpy as np

import dataset
import shufflenetv2

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--src-model', '-s', type=str, help='source model file')
    parser.add_argument('--dst-model', '-d', type=str, help='destination model file')
    parser.add_argument('--num-classes', type=int, default=1, help='number of classes')
    parser.add_argument('--dataset', type=str, default='', help='dataset path')
    args = parser.parse_args()
    print(args)
    
    # dataset = dataset.CustomDataset(args.dataset, 'train')
    # num_ids = dataset.max_id + 2
    dataset = dataset.HotchpotchDataset('/data/tseng/dataset/jde', './data/train.txt')
    num_ids = int(dataset.max_id + 1)
    print(num_ids)
    
    if '0.5x' in args.src_model:
        model_size = '0.5x'
    elif '1.0x' in args.src_model:
        model_size = '1.0x'
    elif '1.5x' in args.src_model:
        model_size = '1.5x'
    elif '2.0x' in args.src_model:
        model_size = '2.0x'
    else:
        raise ValueError('cannot infer model size from {}'.format(args.src_model))
    
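    # NOTE: the anchors below are random placeholders; presumably only the
    # (12, 2) shape matters for this model-conversion script, not the values.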
    anchors = np.random.randint(low=10, high=150, size=(12,2))
    model = shufflenetv2.ShuffleNetV2(anchors, num_classes=args.num_classes, num_ids=num_ids, model_size=model_size)
    
Code Example #2
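Sweeps the training set once and appends any ground-truth box whose width or height rounds to under 8 pixels to error_box.txt. The snippet is truncated: boxw0, boxh0, and the wh.pkl cache are presumably filled further down in the original script.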
import os
import sys
import torch
import pickle
import numpy as np
from functools import partial
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

sys.path.append('.')
import dataset as ds

boxw0 = []
boxh0 = []
cachefile = './wh.pkl'
if not os.path.isfile(cachefile):
    dataset = ds.HotchpotchDataset('/data/tseng/dataset/jde',
                                   './data/train.txt', 'shufflenetv2', False)
    in_size = torch.IntTensor([320, 576])
    collate_fn = partial(ds.collate_fn, in_size=in_size, train=False)
    data_loader = torch.utils.data.DataLoader(dataset, collate_fn=collate_fn)
    for batch, (images, targets) in enumerate(data_loader):
        n, c, h, w = images.size()
        if targets.size(0) > 0:
            # np.int was removed in NumPy 1.24; plain int is the equivalent here.
            tw = (targets[:, 5] * w).numpy().round().astype(int)
            th = (targets[:, 6] * h).numpy().round().astype(int)
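            # Boxes narrower or shorter than 8 px at network input size are
            # treated as suspect annotations and logged for inspection.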
            mask = (tw < 8) | (th < 8)
            if mask.sum() > 0:
                with open('error_box.txt', 'a') as file:
                    file.write('{}\n'.format(targets.numpy()))
                    file.write('{}x{}\n'.format(h, w))
                    file.write('{}\n'.format(tw))
                    file.write('{}\n\n\n'.format(th))
Code Example #3
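A train() entry point that groups model parameters into backbone/neck, detection, and identity sets with separate SGD learning rates, and anneals the learning rate per batch with a cosine schedule (implemented jointly by lr_lambda and the constant offset added inside the batch loop).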
# Imports assumed by this snippet; utils, dataset (as ds), darknet,
# shufflenetv2, and grouping_model_params are project-local to the JDE repo.
import math
import os
import sys
from collections import defaultdict
from functools import partial

import numpy as np
import torch


def train(args):
    utils.make_workspace_dirs(args.workspace)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    anchors = np.loadtxt(args.anchors) if args.anchors else None
    scale_sampler = utils.TrainScaleSampler(args.in_size, args.scale_step,
                                            args.rescale_freq)
    shared_size = torch.IntTensor(args.in_size).share_memory_()
    logger = utils.get_logger(path=os.path.join(args.workspace, 'log.txt'))

    torch.backends.cudnn.benchmark = True

    dataset = ds.HotchpotchDataset(args.dataset_root, './data/train.txt',
                                   args.backbone)
    collate_fn = partial(ds.collate_fn, in_size=shared_size, train=True)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.workers,
                                              collate_fn=collate_fn,
                                              pin_memory=args.pin,
                                              drop_last=True)

    num_ids = int(dataset.max_id + 1)
    if args.backbone == 'darknet':
        model = darknet.DarkNet(anchors,
                                num_classes=args.num_classes,
                                num_ids=num_ids).to(device)
    elif args.backbone == 'shufflenetv2':
        model = shufflenetv2.ShuffleNetV2(anchors,
                                          num_classes=args.num_classes,
                                          num_ids=num_ids,
                                          model_size=args.thin,
                                          box_loss=args.box_loss,
                                          cls_loss=args.cls_loss).to(device)
    else:
        print('unknown backbone architecture!')
        sys.exit(1)
    if args.checkpoint:
        model.load_state_dict(torch.load(args.checkpoint))
    lr_min = 0.00025
    params = [p for p in model.parameters() if p.requires_grad]
    backbone_neck_params, detection_params, identity_params = grouping_model_params(
        model)
    if args.optim == 'sgd':
        # optimizer = torch.optim.SGD(params, lr=args.lr,
        #     momentum=args.momentum, weight_decay=args.weight_decay)
        optimizer = torch.optim.SGD([{
            'params': backbone_neck_params
        }, {
            'params': detection_params,
            'lr': args.lr * args.lr_coeff[1]
        }, {
            'params': identity_params,
            'lr': args.lr * args.lr_coeff[2]
        }],
                                    lr=(args.lr - lr_min),
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.Adam(params,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)

    if args.freeze_bn:
        for name, param in model.named_parameters():
            if 'norm' in name:
                param.requires_grad = False
                logger.info('freeze {}'.format(name))
            else:
                param.requires_grad = True

    trainer = f'{args.workspace}/checkpoint/trainer-ckpt.pth'
    if args.resume:
        trainer_state = torch.load(trainer)
        optimizer.load_state_dict(trainer_state['optimizer'])

    def lr_lambda(batch):
        # Cosine term in [-0.5, 0.5]; LambdaLR scales it by the base lr
        # (args.lr - lr_min), and the offset added in the batch loop shifts
        # the result into [lr_min, args.lr]: per-batch cosine annealing,
        # restarted each epoch.
        return 0.5 * math.cos(
            (batch % len(data_loader)) / (len(data_loader) - 1) * math.pi)

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    start_epoch = 0
    logger.info(args)
    logger.info('Start training from epoch {}'.format(start_epoch))
    model_path = f'{args.workspace}/checkpoint/{args.savename}-ckpt-%03d.pth'
    size = shared_size.numpy().tolist()

    for epoch in range(start_epoch, args.epochs):
        model.train()
        logger.info(('%8s%10s%10s' + '%10s' * 8) %
                    ('Epoch', 'Batch', 'SIZE', 'LBOX', 'LCLS', 'LIDE', 'LOSS',
                     'SBOX', 'SCLS', 'SIDE', 'LR'))

        rmetrics = defaultdict(float)
        optimizer.zero_grad()
        for batch, (images, targets) in enumerate(data_loader):
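            # Re-add the constant offset so the effective lr becomes
            # lr_min + (args.lr - lr_min) * 0.5 * (1 + cos(...)); see lr_lambda above.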
            for i, g in enumerate(optimizer.param_groups):
                g['lr'] += (args.lr - lr_min) * 0.5 + lr_min
            loss, metrics = model(images.to(device), targets.to(device), size)
            loss.backward()

            if args.sparsity:
                model.correct_bn_grad(args.lamb)

            num_batches = epoch * len(data_loader) + batch + 1
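            # Gradient accumulation: only step the optimizer every
            # args.accumulated_batches batches (or on the epoch's last batch).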
            if ((batch + 1) % args.accumulated_batches
                    == 0) or (batch == len(data_loader) - 1):
                optimizer.step()
                optimizer.zero_grad()

            # Running mean of each loss/metric over the epoch so far.
            for k, v in metrics.items():
                rmetrics[k] = (rmetrics[k] * batch + v) / (batch + 1)

            fmt = tuple(['%g/%g' % (epoch, args.epochs),
                         '%g/%g' % (batch, len(data_loader)),
                         '%gx%g' % (size[0], size[1])] +
                        list(rmetrics.values()) +
                        [optimizer.param_groups[0]['lr']])
            if batch % args.print_interval == 0:
                logger.info(('%8s%10s%10s' + '%10.3g' *
                             (len(rmetrics.values()) + 1)) % fmt)

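            # Multi-scale training: sample a new input resolution and publish
            # it to the data-loader workers through the shared tensor.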
            size = scale_sampler(num_batches)
            shared_size[0], shared_size[1] = size[0], size[1]
            lr_scheduler.step()

        torch.save(model.state_dict(), model_path % epoch)
        torch.save(
            {
                'epoch': epoch,
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict()
            }, trainer)

        if epoch >= args.eval_epoch:
            pass  # evaluation hook, left unimplemented in this snippet
Code Example #4
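Loads a darknet-format weights file into a DarkNet JDE model, starting with the file's version header. The snippet begins midway through the argument parser; the missing lines are reconstructed below from Code Example #1.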
import argparse

import numpy as np

import darknet
import dataset

if __name__ == '__main__':
    # The parser setup through --num-classes is reconstructed from the
    # matching options in Code Example #1; the original snippet began mid-call.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='', help='dataset path')
    parser.add_argument('--num-classes',
                        type=int,
                        default=1,
                        help='number of classes')
    parser.add_argument('--darknet-model',
                        '-dm',
                        type=str,
                        dest='dm',
                        default='darknet.weights',
                        help='darknet-format model file')
    parser.add_argument('--load-backbone-only',
                        '-lbo',
                        dest='lbo',
                        help='only load the backbone',
                        action='store_true')
    args = parser.parse_args()

    dataset = dataset.HotchpotchDataset(args.dataset, './data/train.txt')
    num_ids = int(dataset.max_id + 1)
    print(num_ids)

    model = darknet.DarkNet(np.random.randint(0, 100, (12, 2)),
                            num_classes=args.num_classes,
                            num_ids=num_ids)

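    # A darknet .weights file begins with a header of three int32 values
    # (major, minor, revision) and a count of images seen during training
    # (int64 in recent format versions); the raw float32 weights follow.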
    with open(args.dm, 'rb') as file:
        major = np.fromfile(file, dtype=np.int32, count=1)
        minor = np.fromfile(file, dtype=np.int32, count=1)
        revision = np.fromfile(file, dtype=np.int32, count=1)
        seen = np.fromfile(file, dtype=np.int64, count=1)
        print(f'darknet model version: {major[0]}.{minor[0]}.{revision[0]}')
Code Example #5
File: train.py  Project: fightseed/JDE
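A simpler variant of the train() in Code Example #3: a single parameter group, quartic learning-rate warmup during the first epoch, and a MultiStepLR schedule stepped once per epoch.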
# Imports assumed by this snippet; utils, dataset (as ds), darknet, and
# shufflenetv2 are project-local modules in the JDE repo.
import os
import sys
from collections import defaultdict
from functools import partial

import numpy as np
import torch


def train(args):
    utils.make_workspace_dirs(args.workspace)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    anchors = np.loadtxt(os.path.join(args.dataset, 'anchors.txt'))
    scale_sampler = utils.TrainScaleSampler(args.in_size, args.scale_step,
                                            args.rescale_freq)
    shared_size = torch.IntTensor(args.in_size).share_memory_()
    logger = utils.get_logger(path=os.path.join(args.workspace, 'log.txt'))

    torch.backends.cudnn.benchmark = True

    # dataset = ds.CustomDataset(args.dataset, 'train', args.backbone)
    dataset = ds.HotchpotchDataset('/data/tseng/dataset/jde',
                                   './data/train.txt', args.backbone)
    collate_fn = partial(ds.collate_fn, in_size=shared_size, train=True)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.workers,
                                              collate_fn=collate_fn,
                                              pin_memory=args.pin,
                                              drop_last=True)

    # num_ids = dataset.max_id + 2
    num_ids = int(dataset.max_id + 1)
    if args.backbone == 'darknet':
        model = darknet.DarkNet(anchors,
                                num_classes=args.num_classes,
                                num_ids=num_ids).to(device)
    elif args.backbone == 'shufflenetv2':
        model = shufflenetv2.ShuffleNetV2(anchors,
                                          num_classes=args.num_classes,
                                          num_ids=num_ids,
                                          model_size=args.thin).to(device)
    else:
        print('unknown backbone architecture!')
        sys.exit(1)
    if args.checkpoint:
        model.load_state_dict(torch.load(args.checkpoint))

    params = [p for p in model.parameters() if p.requires_grad]
    if args.optim == 'sgd':
        optimizer = torch.optim.SGD(params,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.Adam(params,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)

    if args.freeze_bn:
        for name, param in model.named_parameters():
            if 'norm' in name:
                param.requires_grad = False
                logger.info('freeze {}'.format(name))
            else:
                param.requires_grad = True

    trainer = f'{args.workspace}/checkpoint/trainer-ckpt.pth'
    if args.resume:
        trainer_state = torch.load(trainer)
        optimizer.load_state_dict(trainer_state['optimizer'])

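    # A milestone of -1 selects the defaults: decay at 50% and 75% of training.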
    if -1 in args.milestones:
        args.milestones = [int(args.epochs * 0.5), int(args.epochs * 0.75)]
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.milestones, gamma=args.lr_gamma)

    start_epoch = 0
    if args.resume:
        start_epoch = trainer_state['epoch'] + 1
        lr_scheduler.load_state_dict(trainer_state['lr_scheduler'])

    logger.info(args)
    logger.info('Start training from epoch {}'.format(start_epoch))
    model_path = f'{args.workspace}/checkpoint/{args.savename}-ckpt-%03d.pth'
    size = shared_size.numpy().tolist()

    for epoch in range(start_epoch, args.epochs):
        model.train()
        logger.info(('%8s%10s%10s' + '%10s' * 8) %
                    ('Epoch', 'Batch', 'SIZE', 'LBOX', 'LCLS', 'LIDE', 'LOSS',
                     'SB', 'SC', 'SI', 'LR'))

        rmetrics = defaultdict(float)
        optimizer.zero_grad()
        for batch, (images, targets) in enumerate(data_loader):
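            # Quartic warmup over the first `warmup` batches of epoch 0:
            # the lr ramps from 0 up to args.lr as (batch / warmup)**4.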
            warmup = min(args.warmup, len(data_loader))
            if epoch == 0 and batch <= warmup:
                lr = args.lr * (batch / warmup)**4
                for g in optimizer.param_groups:
                    g['lr'] = lr

            loss, metrics = model(images.to(device), targets.to(device), size)
            loss.backward()

            if args.sparsity:
                model.correct_bn_grad(args.lamb)

            num_batches = epoch * len(data_loader) + batch + 1
            if ((batch + 1) % args.accumulated_batches
                    == 0) or (batch == len(data_loader) - 1):
                optimizer.step()
                optimizer.zero_grad()

            for k, v in metrics.items():
                rmetrics[k] = (rmetrics[k] * batch + v) / (batch + 1)

            fmt = tuple(['%g/%g' % (epoch, args.epochs),
                         '%g/%g' % (batch, len(data_loader)),
                         '%gx%g' % (size[0], size[1])] +
                        list(rmetrics.values()) +
                        [optimizer.param_groups[0]['lr']])
            if batch % args.print_interval == 0:
                logger.info(('%8s%10s%10s' + '%10.3g' *
                             (len(rmetrics.values()) + 1)) % fmt)

            size = scale_sampler(num_batches)
            shared_size[0], shared_size[1] = size[0], size[1]

        torch.save(model.state_dict(), model_path % epoch)
        torch.save(
            {
                'epoch': epoch,
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict()
            }, trainer)

        if epoch >= args.eval_epoch:
            pass
        lr_scheduler.step()