Example #1
File: train.py  Project: intjun/JDE
def train(args):
    utils.make_workspace_dirs(args.workspace)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    anchors = np.loadtxt(os.path.join(args.dataset, 'anchors.txt'))
    scale_sampler = utils.TrainScaleSampler(args.in_size, args.scale_step,
                                            args.rescale_freq)
    shared_size = torch.IntTensor(args.in_size).share_memory_()
    logger = utils.get_logger(path=os.path.join(args.workspace, 'log.txt'))

    torch.backends.cudnn.benchmark = True

    dataset = ds.CustomDataset(args.dataset, 'train')
    collate_fn = partial(ds.collate_fn, in_size=shared_size, train=True)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              args.batch_size,
                                              True,
                                              num_workers=args.workers,
                                              collate_fn=collate_fn,
                                              pin_memory=args.pin,
                                              drop_last=True)

    num_ids = dataset.max_id + 2
    if args.backbone == 'darknet':
        model = darknet.DarkNet(anchors,
                                num_classes=args.num_classes,
                                num_ids=num_ids).to(device)
    elif args.backbone == 'shufflenetv2':
        model = shufflenetv2.ShuffleNetV2(anchors,
                                          num_classes=args.num_classes,
                                          num_ids=num_ids,
                                          model_size=args.thin).to(device)
    else:
        print('unknown backbone architecture!')
        sys.exit(1)
    if args.checkpoint:
        model.load_state_dict(torch.load(args.checkpoint))

    params = [p for p in model.parameters() if p.requires_grad]
    if args.optim == 'sgd':
        optimizer = torch.optim.SGD(params,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.Adam(params,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)

    if args.freeze_bn:
        for name, param in model.named_parameters():
            if 'norm' in name:
                param.requires_grad = False
                logger.info('freeze {}'.format(name))
            else:
                param.requires_grad = True

    trainer = f'{args.workspace}/checkpoint/trainer-ckpt.pth'
    if args.resume:
        trainer_state = torch.load(trainer)
        optimizer.load_state_dict(trainer_state['optimizer'])

    if -1 in args.milestones:
        args.milestones = [int(args.epochs * 0.5), int(args.epochs * 0.75)]
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.milestones, gamma=args.lr_gamma)

    start_epoch = 0
    if args.resume:
        start_epoch = trainer_state['epoch'] + 1
        lr_scheduler.load_state_dict(trainer_state['lr_scheduler'])

    logger.info(args)
    logger.info('Start training from epoch {}'.format(start_epoch))
    model_path = f'{args.workspace}/checkpoint/{args.savename}-ckpt-%03d.pth'
    size = shared_size.numpy().tolist()

    for epoch in range(start_epoch, args.epochs):
        model.train()
        logger.info(('%8s%10s%10s' + '%10s' * 8) %
                    ('Epoch', 'Batch', 'SIZE', 'LBOX', 'LCLS', 'LIDE', 'LOSS',
                     'SB', 'SC', 'SI', 'LR'))

        rmetrics = defaultdict(float)
        optimizer.zero_grad()
        for batch, (images, targets) in enumerate(data_loader):
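            # Quartic learning-rate warm-up during the first epoch's first
            # `warmup` batches.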
            warmup = min(args.warmup, len(data_loader))
            if epoch == 0 and batch <= warmup:
                lr = args.lr * (batch / warmup)**4
                for g in optimizer.param_groups:
                    g['lr'] = lr

            loss, metrics = model(images.to(device), targets.to(device), size)
            loss.backward()

            if args.sparsity:
                model.correct_bn_grad(args.lamb)

            num_batches = epoch * len(data_loader) + batch + 1
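            # Gradient accumulation: update the weights only every
            # `accumulated_batches` iterations (or on the last batch).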
            if ((batch + 1) % args.accumulated_batches
                    == 0) or (batch == len(data_loader) - 1):
                optimizer.step()
                optimizer.zero_grad()

            # Running average of each reported metric over the epoch so far.
            for k, v in metrics.items():
                rmetrics[k] = (rmetrics[k] * batch + v) / (batch + 1)

            fmt = tuple(['%g/%g' % (epoch, args.epochs),
                         '%g/%g' % (batch, len(data_loader)),
                         '%gx%g' % (size[0], size[1])] +
                        list(rmetrics.values()) +
                        [optimizer.param_groups[0]['lr']])
            if batch % args.print_interval == 0:
                logger.info(('%8s%10s%10s' + '%10.3g' *
                             (len(rmetrics.values()) + 1)) % fmt)

            # Multi-scale training: sample a new input size and publish it to
            # the DataLoader workers through the shared tensor.
            size = scale_sampler(num_batches)
            shared_size[0], shared_size[1] = size[0], size[1]

        torch.save(model.state_dict(), model_path % epoch)
        torch.save(
            {
                'epoch': epoch,
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict()
            }, trainer)

        if epoch >= args.eval_epoch:
            pass  # placeholder: evaluation is not implemented in this example
        lr_scheduler.step()
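
A rough sketch of the command-line arguments Example #1 reads from `args`. The option names mirror the attribute accesses in the code above; the types and default values are illustrative assumptions, not the project's actual defaults.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--workspace', type=str, default='./workspace')
parser.add_argument('--dataset', type=str, default='./data')        # directory containing anchors.txt
parser.add_argument('--backbone', type=str, default='darknet')      # 'darknet' or 'shufflenetv2'
parser.add_argument('--thin', type=str, default='0.5x')             # ShuffleNetV2 model size (assumed format)
parser.add_argument('--num-classes', type=int, default=1)
parser.add_argument('--in-size', type=int, nargs=2, default=[320, 576])
parser.add_argument('--scale-step', type=int, nargs=2, default=[32, 32])
parser.add_argument('--rescale-freq', type=int, default=80)
parser.add_argument('--batch-size', type=int, default=8)
parser.add_argument('--workers', type=int, default=4)
parser.add_argument('--pin', action='store_true')
parser.add_argument('--checkpoint', type=str, default='')
parser.add_argument('--resume', action='store_true')
parser.add_argument('--optim', type=str, default='sgd')
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--weight-decay', type=float, default=1e-4)
parser.add_argument('--milestones', type=int, nargs='+', default=[-1])  # -1 -> 50%/75% of epochs
parser.add_argument('--lr-gamma', type=float, default=0.1)
parser.add_argument('--warmup', type=int, default=1000)
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--accumulated-batches', type=int, default=1)
parser.add_argument('--freeze-bn', action='store_true')
parser.add_argument('--sparsity', action='store_true')
parser.add_argument('--lamb', type=float, default=0.01)
parser.add_argument('--print-interval', type=int, default=40)
parser.add_argument('--eval-epoch', type=int, default=10)
parser.add_argument('--savename', type=str, default='jde')

args = parser.parse_args()
# train(args)  # requires the project's utils/ds/darknet/shufflenetv2 modules
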
Example #2
def main(args):
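    # Use the 'spawn' start method so CUDA can be used safely in DataLoader
    # worker processes; ignore the RuntimeError if it is already set.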
    try:
        mp.set_start_method('spawn')
    except RuntimeError:
        pass
    
    utils.make_workspace_dirs(args.workspace)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    in_size = [int(insz) for insz in args.in_size.split(',')]
    scale_step = [int(ss) for ss in args.scale_step.split(',')]
    anchors = np.loadtxt(os.path.join(args.dataset, 'anchors.txt'))
    scale_sampler = utils.TrainScaleSampler(scale_step, args.rescale_freq)
    shared_size = torch.IntTensor(in_size).share_memory_()

    dataset_train = ds.CustomDataset(args.dataset, 'train')
    data_loader = torch.utils.data.DataLoader(
        dataset=dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        collate_fn=partial(ds.collate_fn, in_size=shared_size, train=True),
        pin_memory=args.pin)
    
    dataset_valid = ds.CustomDataset(args.dataset, 'test')
    data_loader_valid = torch.utils.data.DataLoader(
        dataset=dataset_valid,
        batch_size=1,
        shuffle=False,
        num_workers=1,
        collate_fn=partial(ds.collate_fn, in_size=torch.IntTensor(in_size), train=False),
        pin_memory=args.pin)

    if args.checkpoint:
        print(f'load {args.checkpoint}')
        model = torch.load(args.checkpoint).to(device)
    else:
        print('please set fine tune model first!')
        return
    
    criterion = yolov3.YOLOv3Loss(args.num_classes, anchors)
    decoder = yolov3.YOLOv3EvalDecoder(in_size, args.num_classes, anchors)
    if args.test_only:
        # decoder added so the call matches the per-epoch evaluation below
        mAP = eval.evaluate(model, decoder, data_loader_valid, device, args.num_classes)
        print('mAP of the current model on the validation dataset: %.2f%%' % (mAP * 100))
        return
    
    params = [p for p in model.parameters() if p.requires_grad]
    if args.optim == 'sgd':
        optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)

    if args.resume:
        trainer_state = torch.load(f'{args.workspace}/checkpoint/trainer-ckpt.pth')
        optimizer.load_state_dict(trainer_state['optimizer'])
 
    milestones = [int(ms) for ms in args.milestones.split(',')]
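    # Learning-rate schedule: quartic warm-up for the first args.warmup
    # iterations, then step decay by lr_gamma at each (iteration-indexed) milestone.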
    def lr_lambda(iter):
        if iter < args.warmup:
            return pow(iter / args.warmup, 4)
        factor = 1
        for i in milestones:
            factor *= pow(args.lr_gamma, int(iter > i))
        return factor

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    if args.resume:
        start_epoch = trainer_state['epoch'] + 1
        lr_scheduler.load_state_dict(trainer_state['lr_scheduler'])
    else:
        start_epoch = 0
    print(f'Start training from epoch {start_epoch}')

    best_mAP = 0
    for epoch in range(start_epoch, args.epochs):
        msgs = train_one_epoch(model, criterion, optimizer, lr_scheduler,
                               data_loader, epoch, args.interval, shared_size,
                               scale_sampler, device, args.sparsity, args.lamb)
        utils.print_training_message(epoch + 1, msgs, args.batch_size)
        torch.save(model, f"{args.workspace}/checkpoint/{args.savename}-ckpt-%03d.pth" % epoch)
        torch.save({
            'epoch' : epoch,
            'optimizer' : optimizer.state_dict(),
            'lr_scheduler' : lr_scheduler.state_dict()}, f'{args.workspace}/checkpoint/trainer-ckpt.pth')
        
        if epoch >= args.eval_epoch:
            mAP = eval.evaluate(model, decoder, data_loader_valid, device, args.num_classes)
            with open(f'{args.workspace}/log/mAP.txt', 'a') as file:
                file.write(f'{epoch} {mAP}\n')
            print('Current mAP: %.2f%%' % (mAP * 100))
Example #3
def train(args):
    utils.make_workspace_dirs(args.workspace)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    anchors = np.loadtxt(args.anchors) if args.anchors else None
    scale_sampler = utils.TrainScaleSampler(args.in_size, args.scale_step,
                                            args.rescale_freq)
    shared_size = torch.IntTensor(args.in_size).share_memory_()
    logger = utils.get_logger(path=os.path.join(args.workspace, 'log.txt'))

    torch.backends.cudnn.benchmark = True

    dataset = ds.HotchpotchDataset(args.dataset_root, './data/train.txt',
                                   args.backbone)
    collate_fn = partial(ds.collate_fn, in_size=shared_size, train=True)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              args.batch_size,
                                              True,
                                              num_workers=args.workers,
                                              collate_fn=collate_fn,
                                              pin_memory=args.pin,
                                              drop_last=True)

    num_ids = int(dataset.max_id + 1)
    if args.backbone == 'darknet':
        model = darknet.DarkNet(anchors,
                                num_classes=args.num_classes,
                                num_ids=num_ids).to(device)
    elif args.backbone == 'shufflenetv2':
        model = shufflenetv2.ShuffleNetV2(anchors,
                                          num_classes=args.num_classes,
                                          num_ids=num_ids,
                                          model_size=args.thin,
                                          box_loss=args.box_loss,
                                          cls_loss=args.cls_loss).to(device)
    else:
        print('unknown backbone architecture!')
        sys.exit(1)
    if args.checkpoint:
        model.load_state_dict(torch.load(args.checkpoint))
    lr_min = 0.00025
    params = [p for p in model.parameters() if p.requires_grad]
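    # Group the parameters (backbone+neck, detection head, identity head) so
    # each group can be given its own learning-rate coefficient below.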
    backbone_neck_params, detection_params, identity_params = grouping_model_params(
        model)
    if args.optim == 'sgd':
        # optimizer = torch.optim.SGD(params, lr=args.lr,
        #     momentum=args.momentum, weight_decay=args.weight_decay)
        optimizer = torch.optim.SGD([{
            'params': backbone_neck_params
        }, {
            'params': detection_params,
            'lr': args.lr * args.lr_coeff[1]
        }, {
            'params': identity_params,
            'lr': args.lr * args.lr_coeff[2]
        }],
                                    lr=(args.lr - lr_min),
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.Adam(params,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)

    if args.freeze_bn:
        for name, param in model.named_parameters():
            if 'norm' in name:
                param.requires_grad = False
                logger.info('freeze {}'.format(name))
            else:
                param.requires_grad = True

    trainer = f'{args.workspace}/checkpoint/trainer-ckpt.pth'
    if args.resume:
        trainer_state = torch.load(trainer)
        optimizer.load_state_dict(trainer_state['optimizer'])

    # Cosine factor over one epoch; combined with the per-batch LR offset added
    # inside the training loop below, the first parameter group's learning rate
    # anneals from args.lr down to lr_min within every epoch.
    def lr_lambda(batch):
        return 0.5 * math.cos(
            (batch % len(data_loader)) / (len(data_loader) - 1) * math.pi)

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    start_epoch = 0
    logger.info(args)
    logger.info('Start training from epoch {}'.format(start_epoch))
    model_path = f'{args.workspace}/checkpoint/{args.savename}-ckpt-%03d.pth'
    size = shared_size.numpy().tolist()

    for epoch in range(start_epoch, args.epochs):
        model.train()
        logger.info(('%8s%10s%10s' + '%10s' * 8) %
                    ('Epoch', 'Batch', 'SIZE', 'LBOX', 'LCLS', 'LIDE', 'LOSS',
                     'SBOX', 'SCLS', 'SIDE', 'LR'))

        rmetrics = defaultdict(float)
        optimizer.zero_grad()
        for batch, (images, targets) in enumerate(data_loader):
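            # Shift the cosine-scaled LR upward: after the scheduler step the LR
            # is base_lr * 0.5*cos(theta); adding 0.5*(args.lr - lr_min) + lr_min
            # turns this into cosine annealing from args.lr down to lr_min
            # (exact for the first parameter group).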
            for i, g in enumerate(optimizer.param_groups):
                g['lr'] += (args.lr - lr_min) * 0.5 + lr_min
            loss, metrics = model(images.to(device), targets.to(device), size)
            loss.backward()

            if args.sparsity:
                model.correct_bn_grad(args.lamb)

            num_batches = epoch * len(data_loader) + batch + 1
            if ((batch + 1) % args.accumulated_batches
                    == 0) or (batch == len(data_loader) - 1):
                optimizer.step()
                optimizer.zero_grad()

            for k, v in metrics.items():
                rmetrics[k] = (rmetrics[k] * batch + metrics[k]) / (batch + 1)

            fmt = tuple(['%g/%g' % (epoch, args.epochs),
                         '%g/%g' % (batch, len(data_loader)),
                         '%gx%g' % (size[0], size[1])] +
                        list(rmetrics.values()) +
                        [optimizer.param_groups[0]['lr']])
            if batch % args.print_interval == 0:
                logger.info(('%8s%10s%10s' + '%10.3g' *
                             (len(rmetrics.values()) + 1)) % fmt)

            size = scale_sampler(num_batches)
            shared_size[0], shared_size[1] = size[0], size[1]
            lr_scheduler.step()

        torch.save(model.state_dict(), model_path % epoch)
        torch.save(
            {
                'epoch': epoch,
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict()
            }, trainer)

        if epoch >= args.eval_epoch:
            pass
Example #4
def main(args):
    try:
        mp.set_start_method('spawn')
    except RuntimeError:
        pass

    utils.make_workspace_dirs(args.workspace)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    in_size = [int(s) for s in args.in_size.split(',')]
    scale_step = [int(ss) for ss in args.scale_step.split(',')]
    anchors = np.loadtxt(os.path.join(args.dataset, 'anchors.txt'))
    scale_sampler = utils.TrainScaleSampler(scale_step, args.rescale_freq)
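    # Shared tensor through which the training loop publishes the current input
    # size to the DataLoader workers (used by collate_fn below).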
    shared_size = torch.IntTensor(in_size).share_memory_()

    dataset = ds.CustomDataset(args.dataset, 'train')
    collate_fn = partial(ds.collate_fn, in_size=shared_size, train=True)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              args.batch_size,
                                              True,
                                              num_workers=args.workers,
                                              collate_fn=collate_fn,
                                              pin_memory=args.pin)

    model = darknet.DarkNet(anchors, in_size,
                            num_classes=args.num_classes).to(device)
    if args.checkpoint:
        print(f'load {args.checkpoint}')
        model.load_state_dict(torch.load(args.checkpoint))
    if args.sparsity:
        model.load_prune_permit('model/prune_permit.json')

    criterion = yolov3.YOLOv3Loss(args.num_classes, anchors)
    decoder = yolov3.YOLOv3EvalDecoder(in_size, args.num_classes, anchors)

    params = [p for p in model.parameters() if p.requires_grad]
    if args.optim == 'sgd':
        optimizer = torch.optim.SGD(params,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.Adam(params,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)

    trainer = f'{args.workspace}/checkpoint/trainer-ckpt.pth'
    if args.resume:
        trainer_state = torch.load(trainer)
        optimizer.load_state_dict(trainer_state['optimizer'])

    milestones = [int(ms) for ms in args.milestones.split(',')]

    # Same warm-up plus milestone-decay schedule as in Example #2.
    def lr_lambda(iter):
        if iter < args.warmup:
            return pow(iter / args.warmup, 4)
        factor = 1
        for i in milestones:
            factor *= pow(args.lr_gamma, int(iter > i))
        return factor

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    if args.resume:
        start_epoch = trainer_state['epoch'] + 1
        lr_scheduler.load_state_dict(trainer_state['lr_scheduler'])
    else:
        start_epoch = 0
    print(f'Start training from epoch {start_epoch}')

    for epoch in range(start_epoch, args.epochs):
        msgs = train_one_epoch(model, criterion, optimizer, lr_scheduler,
                               data_loader, epoch, args.interval, shared_size,
                               scale_sampler, device, args.sparsity, args.lamb)
        utils.print_training_message(args.workspace, epoch + 1, msgs,
                                     args.batch_size)
        torch.save(
            model.state_dict(),
            f"{args.workspace}/checkpoint/{args.savename}-ckpt-%03d.pth" %
            epoch)
        torch.save(
            {
                'epoch': epoch,
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict()
            }, trainer)

        if epoch >= args.eval_epoch:
            pass
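
As a quick sanity check, the warm-up plus milestone-decay schedule used in Examples #2 and #4 can be evaluated on its own. The warm-up length, milestones, gamma, and base learning rate below are illustrative assumptions, not values from the project.

# Standalone check of the LambdaLR factor from Examples #2 and #4.
warmup, milestones, lr_gamma, base_lr = 1000, [20000, 30000], 0.1, 0.01

def lr_lambda(it):
    if it < warmup:
        return pow(it / warmup, 4)              # quartic warm-up
    factor = 1
    for m in milestones:
        factor *= pow(lr_gamma, int(it > m))    # decay after each milestone
    return factor

for it in (0, 500, 1000, 20001, 30001):
    print(it, base_lr * lr_lambda(it))
# -> 0.0, 0.000625, 0.01, 0.001, 1e-4 (approximately)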