示例#1
0
    def __init__(self, anchors, num_classes: int = 1, num_ids: int = 0, model_size: str = '2.0x'):
        """Build a ShuffleNetV2 backbone with three YOLO-style detection/embedding heads.

        Args:
            anchors: Anchor boxes. Stored as ``self.anchors`` and forwarded to
                ``yolov3.YOLOv3Loss`` when ``num_ids > 0``.
            num_classes: Number of object classes. Each detection head emits
                ``(5 + num_classes) * 4`` channels (presumably 4 anchors per scale,
                each with box + objectness + class scores -- TODO confirm against
                the decoder/loss).
            num_ids: Number of identities. When positive, a linear identity
                classifier and a ``yolov3.YOLOv3Loss`` criterion are built. Otherwise
                both attributes are empty ``torch.nn.Sequential`` placeholders.
            model_size: Width multiplier; one of ``'0.5x'``, ``'1.0x'``, ``'1.5x'``
                or ``'2.0x'``.

        Raises:
            NotImplementedError: If ``model_size`` is not a supported value.

        Note:
            Submodules are registered in the order they are assigned to ``self``.
            That order fixes the order of ``self.parameters()``, and optimizer
            state_dicts index parameters by position. Do not reorder the
            constructions below if existing optimizer checkpoints must still load.
        """
        super(ShuffleNetV2, self).__init__()
        self.anchors = anchors
        self.num_classes = num_classes
        # Output channels of every detection head (see ``num_classes`` above).
        self.detection_channels = (5 + self.num_classes) * 4
        # Output channels of every embedding head; also the classifier's input size.
        self.embedding_channels = 128

        # Per-width channel plan. Index 0 is unused; 1: stem conv; 2-4: stages 2-4;
        # 5-7: output widths of the YOLO1 / YOLO2 / YOLO3 head trunks.
        if model_size == '0.5x':
            self.stage_out_channels = [-1, 24, 48, 96, 192, 128, 128, 128]
        elif model_size == '1.0x':
            self.stage_out_channels = [-1, 24, 116, 232, 464, 512, 256, 128]
        elif model_size == '1.5x':
            self.stage_out_channels = [-1, 24, 176, 352, 704, 512, 256, 128]
        elif model_size == '2.0x':
            self.stage_out_channels = [-1, 24, 244, 488, 976, 512, 256, 128]
        else:
            raise NotImplementedError

        # Backbone

        in_channels = 3
        out_channels = self.stage_out_channels[1]
        # Stem: 3x3 stride-2 Conv-BN-ReLU, followed by a stride-2 max-pool.
        self.conv1 = torch.nn.Sequential(\
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False),
            torch.nn.BatchNorm2d(num_features=out_channels),
            torch.nn.ReLU(inplace=True))

        self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stage 2: one stride-2 block, then 3 stride-1 blocks. Stride-1 blocks are
        # passed in_channels // 2, presumably because ShuffleNetV2Block splits its
        # input and transforms only one half -- TODO confirm against that class.
        self.stage2 = []
        in_channels = out_channels
        out_channels = self.stage_out_channels[2]
        self.stage2.append(
            ShuffleNetV2Block(in_channels,
                              out_channels,
                              mid_channels=out_channels // 2,
                              kernel_size=3,
                              stride=2))
        in_channels = out_channels
        for r in range(3):
            self.stage2.append(
                ShuffleNetV2Block(in_channels // 2,
                                  out_channels,
                                  mid_channels=out_channels // 2,
                                  kernel_size=3,
                                  stride=1))
        # Wrap the list in Sequential so the blocks are registered as submodules.
        self.stage2 = torch.nn.Sequential(*self.stage2)

        # Stage 3: one stride-2 block, then 7 stride-1 blocks. in_channels still
        # holds the stage-2 width here.
        self.stage3 = []
        out_channels = self.stage_out_channels[3]
        self.stage3.append(
            ShuffleNetV2Block(in_channels,
                              out_channels,
                              mid_channels=out_channels // 2,
                              kernel_size=3,
                              stride=2))
        in_channels = out_channels
        for r in range(7):
            self.stage3.append(
                ShuffleNetV2Block(in_channels // 2,
                                  out_channels,
                                  mid_channels=out_channels // 2,
                                  kernel_size=3,
                                  stride=1))
        self.stage3 = torch.nn.Sequential(*self.stage3)

        # Stage 4: one stride-2 block, then 3 stride-1 blocks.
        self.stage4 = []
        out_channels = self.stage_out_channels[4]
        self.stage4.append(
            ShuffleNetV2Block(in_channels,
                              out_channels,
                              mid_channels=out_channels // 2,
                              kernel_size=3,
                              stride=2))
        in_channels = out_channels
        for r in range(3):
            self.stage4.append(
                ShuffleNetV2Block(in_channels // 2,
                                  out_channels,
                                  mid_channels=out_channels // 2,
                                  kernel_size=3,
                                  stride=1))
        self.stage4 = torch.nn.Sequential(*self.stage4)

        # YOLO1 head on the stage-4 output: 1x1 Conv-BN-ReLU from
        # stage_out_channels[4] to stage_out_channels[5] (192->128 for '0.5x').

        in_channels = out_channels
        out_channels = self.stage_out_channels[5]
        self.conv5 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels,
                            out_channels,
                            kernel_size=1,
                            stride=1,
                            padding=0,
                            bias=False),
            torch.nn.BatchNorm2d(num_features=out_channels),
            torch.nn.ReLU(inplace=True))

        in_channels = out_channels
        # Detection branch: 3 stride-1 blocks, then a 3x3 conv to detection_channels.
        self.shbk6 = []
        for repeat in range(3):
            self.shbk6.append(
                ShuffleNetV2Block(in_channels // 2,
                                  out_channels,
                                  mid_channels=out_channels // 2,
                                  kernel_size=3,
                                  stride=1))
        self.shbk6 = torch.nn.Sequential(*self.shbk6)
        self.conv7 = torch.nn.Conv2d(in_channels,
                                     out_channels=self.detection_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1,
                                     bias=True)

        # Embedding branch: 3 stride-1 blocks, then a 3x3 conv to embedding_channels.
        self.shbk8 = []
        for repeat in range(3):
            self.shbk8.append(
                ShuffleNetV2Block(in_channels // 2,
                                  out_channels,
                                  mid_channels=out_channels // 2,
                                  kernel_size=3,
                                  stride=1))
        self.shbk8 = torch.nn.Sequential(*self.shbk8)
        self.conv9 = torch.nn.Conv2d(in_channels,
                                     out_channels=self.embedding_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1,
                                     bias=True)

        # YOLO2 head. Its input width is stage_out_channels[3] + stage_out_channels[5]
        # (96+128=224 -> 128 for '0.5x'). That sizing matches a channel concatenation
        # of stage-3 features with YOLO1-trunk features, presumably done in forward()
        # (not shown here).

        in_channels = self.stage_out_channels[3] + self.stage_out_channels[5]
        out_channels = self.stage_out_channels[6]
        # Depthwise-separable fusion: 3x3 depthwise conv + BN (no activation), then a
        # 1x1 pointwise conv + BN + ReLU. It replaces a plain 3x3 Conv-BN-ReLU used
        # in an earlier version.
        self.conv10 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels,
                            in_channels,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                            groups=in_channels,
                            bias=False),
            torch.nn.BatchNorm2d(num_features=in_channels),
            torch.nn.Conv2d(in_channels,
                            out_channels,
                            kernel_size=1,
                            stride=1,
                            padding=0,
                            bias=False),
            torch.nn.BatchNorm2d(num_features=out_channels),
            torch.nn.ReLU(inplace=True))

        in_channels = out_channels
        # Detection branch: 3 stride-1 blocks, then a 3x3 conv to detection_channels.
        self.shbk11 = []
        for repeat in range(3):
            self.shbk11.append(
                ShuffleNetV2Block(in_channels // 2,
                                  out_channels,
                                  mid_channels=out_channels // 2,
                                  kernel_size=3,
                                  stride=1))
        self.shbk11 = torch.nn.Sequential(*self.shbk11)
        self.conv12 = torch.nn.Conv2d(in_channels,
                                      out_channels=self.detection_channels,
                                      kernel_size=3,
                                      stride=1,
                                      padding=1,
                                      bias=True)

        # Embedding branch: 3 stride-1 blocks, then a 3x3 conv to embedding_channels.
        self.shbk13 = []
        for repeat in range(3):
            self.shbk13.append(
                ShuffleNetV2Block(in_channels // 2,
                                  out_channels,
                                  mid_channels=out_channels // 2,
                                  kernel_size=3,
                                  stride=1))
        self.shbk13 = torch.nn.Sequential(*self.shbk13)
        self.conv14 = torch.nn.Conv2d(in_channels,
                                      out_channels=self.embedding_channels,
                                      kernel_size=3,
                                      stride=1,
                                      padding=1,
                                      bias=True)

        # YOLO3 head. Its input width is stage_out_channels[2] + stage_out_channels[6]
        # (48+128=176 -> 128 for '0.5x'). That sizing matches a channel concatenation
        # of stage-2 features with YOLO2-trunk features, presumably done in forward()
        # (not shown here).

        in_channels = self.stage_out_channels[2] + self.stage_out_channels[6]
        out_channels = self.stage_out_channels[7]
        # Depthwise-separable fusion: 3x3 depthwise conv + BN (no activation), then a
        # 1x1 pointwise conv + BN + ReLU. It replaces a plain 3x3 Conv-BN-ReLU used
        # in an earlier version.
        self.conv15 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels,
                            in_channels,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                            groups=in_channels,
                            bias=False),
            torch.nn.BatchNorm2d(num_features=in_channels),
            torch.nn.Conv2d(in_channels,
                            out_channels,
                            kernel_size=1,
                            stride=1,
                            padding=0,
                            bias=False),
            torch.nn.BatchNorm2d(num_features=out_channels),
            torch.nn.ReLU(inplace=True))

        in_channels = out_channels
        # Detection branch: 3 stride-1 blocks, then a 3x3 conv to detection_channels.
        self.shbk16 = []
        for repeat in range(3):
            self.shbk16.append(
                ShuffleNetV2Block(in_channels // 2,
                                  out_channels,
                                  mid_channels=out_channels // 2,
                                  kernel_size=3,
                                  stride=1))
        self.shbk16 = torch.nn.Sequential(*self.shbk16)
        self.conv17 = torch.nn.Conv2d(in_channels,
                                      out_channels=self.detection_channels,
                                      kernel_size=3,
                                      stride=1,
                                      padding=1,
                                      bias=True)

        # Embedding branch: 3 stride-1 blocks, then a 3x3 conv to embedding_channels.
        self.shbk18 = []
        for repeat in range(3):
            self.shbk18.append(
                ShuffleNetV2Block(in_channels // 2,
                                  out_channels,
                                  mid_channels=out_channels // 2,
                                  kernel_size=3,
                                  stride=1))
        self.shbk18 = torch.nn.Sequential(*self.shbk18)
        self.conv19 = torch.nn.Conv2d(in_channels,
                                      out_channels=self.embedding_channels,
                                      kernel_size=3,
                                      stride=1,
                                      padding=1,
                                      bias=True)
        '''Shared identifiers classifier'''

        # A single identity classifier over embedding_channels-dim embeddings, plus
        # the YOLOv3Loss criterion (given num_ids and the embedding size). Both are
        # empty Sequential placeholders when num_ids == 0.
        self.classifier = torch.nn.Linear(
            self.embedding_channels,
            num_ids) if num_ids > 0 else torch.nn.Sequential()
        self.criterion = yolov3.YOLOv3Loss(
            num_classes, anchors, num_ids, embd_dim=self.embedding_channels
        ) if num_ids > 0 else torch.nn.Sequential()

        # Name-mangled call (_ShuffleNetV2__init_weights); defined elsewhere in the class.
        self.__init_weights()
示例#2
0
    def __init__(self, anchors, num_classes: int = 1, num_ids: int = 0, model_size: str = '2.0x'):
        """Build a ShuffleNetV2 backbone with three YOLO-style detection/embedding heads.

        In this variant each head is one stride-1 ShuffleNetV2 block, then a 3x3
        Conv-BN-ReLU fusion conv, then two sibling 3x3 output convs: detection and
        embedding.

        Args:
            anchors: Anchor boxes. Stored as ``self.anchors`` and forwarded to
                ``yolov3.YOLOv3Loss`` when ``num_ids > 0``.
            num_classes: Number of object classes. Each detection head emits
                ``(5 + num_classes) * 4`` channels (presumably 4 anchors per scale,
                each with box + objectness + class scores -- TODO confirm against
                the decoder/loss).
            num_ids: Number of identities. When positive, a linear identity
                classifier and a ``yolov3.YOLOv3Loss`` criterion are built. Otherwise
                both attributes are empty ``torch.nn.Sequential`` placeholders.
            model_size: Width multiplier; one of ``'0.5x'``, ``'1.0x'``, ``'1.5x'``
                or ``'2.0x'``.

        Raises:
            NotImplementedError: If ``model_size`` is not a supported value.

        Note:
            Submodules are registered in the order they are assigned to ``self``.
            That order fixes the order of ``self.parameters()``, and optimizer
            state_dicts index parameters by position. Do not reorder the
            constructions below if existing optimizer checkpoints must still load.
        """
        super(ShuffleNetV2, self).__init__()
        self.anchors = anchors
        self.num_classes = num_classes
        # Output channels of every detection head (see ``num_classes`` above).
        self.detection_channels = (5 + self.num_classes) * 4
        # Output channels of every embedding head; also the classifier's input size.
        self.embedding_channels = 128

        # Per-width channel plan. Index 0 is unused; 1: stem conv; 2-4: stages 2-4;
        # 5-7: output widths of the YOLO1 / YOLO2 / YOLO3 fusion convs.
        if model_size == '0.5x':
            self.stage_out_channels = [-1, 24, 48, 96, 192, 512, 256, 128]
        elif model_size == '1.0x':
            self.stage_out_channels = [-1, 24, 116, 232, 464, 512, 256, 128]
        elif model_size == '1.5x':
            self.stage_out_channels = [-1, 24, 176, 352, 704, 512, 256, 128]
        elif model_size == '2.0x':
            self.stage_out_channels = [-1, 24, 244, 488, 976, 512, 256, 128]
        else:
            raise NotImplementedError

        # Backbone

        in_channels = 3
        out_channels = self.stage_out_channels[1]
        # Stem: 3x3 stride-2 Conv-BN-ReLU, followed by a stride-2 max-pool.
        self.conv1 = torch.nn.Sequential(\
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False),
            torch.nn.BatchNorm2d(num_features=out_channels),
            torch.nn.ReLU(inplace=True))

        self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stage 2: one stride-2 block, then 3 stride-1 blocks. Stride-1 blocks are
        # passed in_channels // 2, presumably because ShuffleNetV2Block splits its
        # input and transforms only one half -- TODO confirm against that class.
        self.stage2 = []
        in_channels = out_channels
        out_channels = self.stage_out_channels[2]
        self.stage2.append(
            ShuffleNetV2Block(in_channels,
                              out_channels,
                              mid_channels=out_channels // 2,
                              kernel_size=3,
                              stride=2))
        in_channels = out_channels
        for r in range(3):
            self.stage2.append(
                ShuffleNetV2Block(in_channels // 2,
                                  out_channels,
                                  mid_channels=out_channels // 2,
                                  kernel_size=3,
                                  stride=1))
        # Wrap the list in Sequential so the blocks are registered as submodules.
        self.stage2 = torch.nn.Sequential(*self.stage2)

        # Stage 3: one stride-2 block, then 7 stride-1 blocks. in_channels still
        # holds the stage-2 width here.
        self.stage3 = []
        out_channels = self.stage_out_channels[3]
        self.stage3.append(
            ShuffleNetV2Block(in_channels,
                              out_channels,
                              mid_channels=out_channels // 2,
                              kernel_size=3,
                              stride=2))
        in_channels = out_channels
        for r in range(7):
            self.stage3.append(
                ShuffleNetV2Block(in_channels // 2,
                                  out_channels,
                                  mid_channels=out_channels // 2,
                                  kernel_size=3,
                                  stride=1))
        self.stage3 = torch.nn.Sequential(*self.stage3)

        # Stage 4: one stride-2 block, then 3 stride-1 blocks.
        self.stage4 = []
        out_channels = self.stage_out_channels[4]
        self.stage4.append(
            ShuffleNetV2Block(in_channels,
                              out_channels,
                              mid_channels=out_channels // 2,
                              kernel_size=3,
                              stride=2))
        in_channels = out_channels
        for r in range(3):
            self.stage4.append(
                ShuffleNetV2Block(in_channels // 2,
                                  out_channels,
                                  mid_channels=out_channels // 2,
                                  kernel_size=3,
                                  stride=1))
        self.stage4 = torch.nn.Sequential(*self.stage4)

        # YOLO1 head on the stage-4 output: one stride-1 block (width unchanged), then
        # a 3x3 Conv-BN-ReLU to stage_out_channels[5] (192->192->512 for '0.5x').

        self.stage5 = []
        in_channels = out_channels
        for repeat in range(1):
            self.stage5.append(
                ShuffleNetV2Block(in_channels // 2,
                                  out_channels,
                                  mid_channels=out_channels // 2,
                                  kernel_size=3,
                                  stride=1))
        self.stage5 = torch.nn.Sequential(*self.stage5)
        out_channels = self.stage_out_channels[5]

        # Fusion conv: 3x3 Conv-BN-ReLU, stage-4 width -> stage_out_channels[5].
        self.conv6 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels,
                            out_channels,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                            bias=False),
            torch.nn.BatchNorm2d(num_features=out_channels),
            torch.nn.ReLU(inplace=True))

        in_channels = out_channels
        # Sibling 3x3 output convs on the fused features: detection (conv11) and
        # embedding (conv12).
        self.conv11 = torch.nn.Conv2d(in_channels,
                                      out_channels=self.detection_channels,
                                      kernel_size=3,
                                      stride=1,
                                      padding=1,
                                      bias=True)
        self.conv12 = torch.nn.Conv2d(in_channels,
                                      out_channels=self.embedding_channels,
                                      kernel_size=3,
                                      stride=1,
                                      padding=1,
                                      bias=True)

        # YOLO2 head. Its input width is stage_out_channels[3] + stage_out_channels[5]
        # (96+512=608 -> 608 -> 256 for '0.5x'). That sizing matches a channel
        # concatenation of stage-3 and YOLO1 features, presumably done in forward()
        # (not shown here).

        in_channels = self.stage_out_channels[3] + out_channels
        # The stride-1 block keeps the concatenated width.
        out_channels = in_channels
        self.stage7 = []
        for repeat in range(1):
            self.stage7.append(
                ShuffleNetV2Block(in_channels // 2,
                                  out_channels,
                                  mid_channels=out_channels // 2,
                                  kernel_size=3,
                                  stride=1))
        self.stage7 = torch.nn.Sequential(*self.stage7)

        # Fusion conv: 3x3 Conv-BN-ReLU down to stage_out_channels[6].
        out_channels = self.stage_out_channels[6]
        self.conv8 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels,
                            out_channels,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                            bias=False),
            torch.nn.BatchNorm2d(num_features=out_channels),
            torch.nn.ReLU(inplace=True))

        in_channels = out_channels
        # Sibling 3x3 output convs: detection (conv13) and embedding (conv14).
        self.conv13 = torch.nn.Conv2d(in_channels,
                                      out_channels=self.detection_channels,
                                      kernel_size=3,
                                      stride=1,
                                      padding=1,
                                      bias=True)
        self.conv14 = torch.nn.Conv2d(in_channels,
                                      out_channels=self.embedding_channels,
                                      kernel_size=3,
                                      stride=1,
                                      padding=1,
                                      bias=True)

        # YOLO3 head. Its input width is stage_out_channels[2] + stage_out_channels[6]
        # (48+256=304 -> 304 -> 128 for '0.5x'). That sizing matches a channel
        # concatenation of stage-2 and YOLO2 features, presumably done in forward()
        # (not shown here).

        in_channels = self.stage_out_channels[2] + out_channels
        # The stride-1 block keeps the concatenated width.
        out_channels = in_channels
        self.stage9 = []
        for repeat in range(1):
            self.stage9.append(
                ShuffleNetV2Block(in_channels // 2,
                                  out_channels,
                                  mid_channels=out_channels // 2,
                                  kernel_size=3,
                                  stride=1))
        self.stage9 = torch.nn.Sequential(*self.stage9)

        # Fusion conv: 3x3 Conv-BN-ReLU down to stage_out_channels[7].
        out_channels = self.stage_out_channels[7]
        self.conv10 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels,
                            out_channels,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                            bias=False),
            torch.nn.BatchNorm2d(num_features=out_channels),
            torch.nn.ReLU(inplace=True))

        in_channels = out_channels
        # Sibling 3x3 output convs: detection (conv15) and embedding (conv16).
        self.conv15 = torch.nn.Conv2d(in_channels,
                                      out_channels=self.detection_channels,
                                      kernel_size=3,
                                      stride=1,
                                      padding=1,
                                      bias=True)
        self.conv16 = torch.nn.Conv2d(in_channels,
                                      out_channels=self.embedding_channels,
                                      kernel_size=3,
                                      stride=1,
                                      padding=1,
                                      bias=True)
        '''Shared identifiers classifier'''

        # A single identity classifier over embedding_channels-dim embeddings, plus
        # the YOLOv3Loss criterion (given num_ids and the embedding size). Both are
        # empty Sequential placeholders when num_ids == 0.
        self.classifier = torch.nn.Linear(
            self.embedding_channels,
            num_ids) if num_ids > 0 else torch.nn.Sequential()
        self.criterion = yolov3.YOLOv3Loss(
            num_classes, anchors, num_ids, embd_dim=self.embedding_channels
        ) if num_ids > 0 else torch.nn.Sequential()

        # Name-mangled call (_ShuffleNetV2__init_weights); defined elsewhere in the class.
        self.__init_weights()
def main(args):
    """Fine-tune a detector loaded from ``args.checkpoint`` and periodically evaluate it.

    The whole model object is loaded with ``torch.load``. Running without a
    checkpoint is not supported: the function prints a message and returns.

    Args:
        args: Parsed command-line namespace. Fields read include ``workspace``,
            ``dataset``, ``in_size`` / ``scale_step`` / ``milestones``
            (comma-separated ints), ``checkpoint``, ``num_classes``, ``test_only``,
            ``optim``, ``lr``, ``momentum``, ``weight_decay``, ``resume``,
            ``warmup``, ``lr_gamma``, ``epochs``, ``eval_epoch`` and ``savename``.
            Data-loader options (``batch_size``, ``workers``, ``pin``,
            ``rescale_freq``) and training options (``interval``, ``sparsity``,
            ``lamb``) are also read and forwarded to helpers.

    Side effects:
        After every epoch, writes ``{workspace}/checkpoint/{savename}-ckpt-NNN.pth``
        (the whole model) and ``{workspace}/checkpoint/trainer-ckpt.pth`` (epoch,
        optimizer and scheduler state). Once ``epoch >= eval_epoch``, also appends
        ``"<epoch> <mAP>"`` lines to ``{workspace}/log/mAP.txt``.
    """
    try:
        mp.set_start_method('spawn')
    except RuntimeError:
        # The start method can only be set once per process; keep the existing one.
        pass

    utils.make_workspace_dirs(args.workspace)
    if not args.checkpoint:
        # Fail fast, before any dataset is built: this script only fine-tunes.
        print('please set fine tune model first!')
        return

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    in_size = [int(insz) for insz in args.in_size.split(',')]
    scale_step = [int(ss) for ss in args.scale_step.split(',')]
    anchors = np.loadtxt(os.path.join(args.dataset, 'anchors.txt'))
    scale_sampler = utils.TrainScaleSampler(scale_step, args.rescale_freq)
    # The training input size lives in shared memory so DataLoader worker
    # processes can observe changes to it (presumably made by the training loop
    # via scale_sampler for multi-scale training).
    shared_size = torch.IntTensor(in_size).share_memory_()

    dataset_train = ds.CustomDataset(args.dataset, 'train')
    data_loader = torch.utils.data.DataLoader(
        dataset=dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        collate_fn=partial(ds.collate_fn, in_size=shared_size, train=True),
        pin_memory=args.pin)

    dataset_valid = ds.CustomDataset(args.dataset, 'test')
    data_loader_valid = torch.utils.data.DataLoader(
        dataset=dataset_valid,
        batch_size=1,
        shuffle=False,
        num_workers=1,
        # Validation uses a fixed, unshared input size.
        collate_fn=partial(ds.collate_fn, in_size=torch.IntTensor(in_size), train=False),
        pin_memory=args.pin)

    print(f'load {args.checkpoint}')
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    model = torch.load(args.checkpoint, map_location=device).to(device)

    criterion = yolov3.YOLOv3Loss(args.num_classes, anchors)
    decoder = yolov3.YOLOv3EvalDecoder(in_size, args.num_classes, anchors)
    if args.test_only:
        # Same argument list as the per-epoch evaluation below; the decoder was
        # previously omitted here, shifting every later positional argument.
        mAP = eval.evaluate(model, decoder, data_loader_valid, device, args.num_classes)
        print('mAP of current model on validation dataset:%.2f%%' % (mAP * 100))
        return

    params = [p for p in model.parameters() if p.requires_grad]
    if args.optim == 'sgd':
        optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)

    trainer_ckpt_path = f'{args.workspace}/checkpoint/trainer-ckpt.pth'
    if args.resume:
        trainer_state = torch.load(trainer_ckpt_path, map_location=device)
        optimizer.load_state_dict(trainer_state['optimizer'])

    milestones = [int(ms) for ms in args.milestones.split(',')]

    def lr_lambda(step):
        """Return the LR multiplier for scheduler step ``step``.

        The first ``args.warmup`` steps ramp up as (step / warmup) ** 4. After
        that, the multiplier is ``args.lr_gamma`` raised to the number of
        milestones strictly passed.
        """
        if step < args.warmup:
            return pow(step / args.warmup, 4)
        factor = 1
        for milestone in milestones:
            factor *= pow(args.lr_gamma, int(step > milestone))
        return factor

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    if args.resume:
        start_epoch = trainer_state['epoch'] + 1
        lr_scheduler.load_state_dict(trainer_state['lr_scheduler'])
    else:
        start_epoch = 0
    print(f'Start training from epoch {start_epoch}')

    for epoch in range(start_epoch, args.epochs):
        msgs = train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, epoch, args.interval, shared_size, scale_sampler, device, args.sparsity, args.lamb)
        utils.print_training_message(epoch + 1, msgs, args.batch_size)
        # Format only the epoch. Applying `%` to the whole interpolated path broke
        # whenever workspace/savename contained a '%' character.
        torch.save(model, f"{args.workspace}/checkpoint/{args.savename}-ckpt-{epoch:03d}.pth")
        torch.save({
            'epoch': epoch,
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict()}, trainer_ckpt_path)

        if epoch >= args.eval_epoch:
            mAP = eval.evaluate(model, decoder, data_loader_valid, device, args.num_classes)
            with open(f'{args.workspace}/log/mAP.txt', 'a') as log_file:
                log_file.write(f'{epoch} {mAP}\n')
            print('Current mAP:%.2f%%' % (mAP * 100))
示例#4
0
    def __init__(self, anchors, num_classes=1, num_ids=0):
        super(DarkNet, self).__init__()
        self.num_classes = num_classes
        self.momentum = 0.1
        self.negative_slope = 0.1
        self.detection_channels = (5 + self.num_classes) * 4
        self.embedding_channels = 512
        '''backbone'''

        self.cbrl1 = ConvBnReLU(in_channels=3,
                                out_channels=32,
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                momentum=self.momentum,
                                negative_slope=self.negative_slope)
        self.zpad1 = torch.nn.ZeroPad2d(padding=(1, 0, 1, 0))
        self.cbrl2 = ConvBnReLU(in_channels=32,
                                out_channels=64,
                                kernel_size=3,
                                stride=2,
                                padding=0,
                                momentum=self.momentum,
                                negative_slope=self.negative_slope)

        self.stage1 = Residual(in_channels=64,
                               mid_channels=32,
                               out_channels=64,
                               momentum=self.momentum,
                               negative_slope=self.negative_slope)
        self.zpad2 = torch.nn.ZeroPad2d(padding=(1, 0, 1, 0))
        self.cbrl3 = ConvBnReLU(in_channels=64,
                                out_channels=128,
                                kernel_size=3,
                                stride=2,
                                padding=0,
                                momentum=self.momentum,
                                negative_slope=self.negative_slope)

        self.stage2 = []
        for repeate in range(2):
            self.stage2.append(
                Residual(in_channels=128,
                         mid_channels=64,
                         out_channels=128,
                         momentum=self.momentum,
                         negative_slope=self.negative_slope))
        self.stage2 = torch.nn.Sequential(*self.stage2)
        self.zpad3 = torch.nn.ZeroPad2d(padding=(1, 0, 1, 0))
        self.cbrl4 = ConvBnReLU(in_channels=128,
                                out_channels=256,
                                kernel_size=3,
                                stride=2,
                                padding=0,
                                momentum=self.momentum,
                                negative_slope=self.negative_slope)

        self.stage3 = []
        for repeate in range(8):
            self.stage3.append(
                Residual(in_channels=256,
                         mid_channels=128,
                         out_channels=256,
                         momentum=self.momentum,
                         negative_slope=self.negative_slope))
        self.stage3 = torch.nn.Sequential(*self.stage3)
        self.zpad4 = torch.nn.ZeroPad2d(padding=(1, 0, 1, 0))
        self.cbrl5 = ConvBnReLU(in_channels=256,
                                out_channels=512,
                                kernel_size=3,
                                stride=2,
                                padding=0,
                                momentum=self.momentum,
                                negative_slope=self.negative_slope)

        self.stage4 = []
        for repeate in range(8):
            self.stage4.append(
                Residual(in_channels=512,
                         mid_channels=256,
                         out_channels=512,
                         momentum=self.momentum,
                         negative_slope=self.negative_slope))
        self.stage4 = torch.nn.Sequential(*self.stage4)
        self.zpad5 = torch.nn.ZeroPad2d(padding=(1, 0, 1, 0))
        self.cbrl6 = ConvBnReLU(in_channels=512,
                                out_channels=1024,
                                kernel_size=3,
                                stride=2,
                                padding=0,
                                momentum=self.momentum,
                                negative_slope=self.negative_slope)

        self.stage5 = []
        for repeate in range(4):
            self.stage5.append(
                Residual(in_channels=1024,
                         mid_channels=512,
                         out_channels=1024,
                         momentum=self.momentum,
                         negative_slope=self.negative_slope))
        self.stage5 = torch.nn.Sequential(*self.stage5)
        '''YOLO1'''

        self.pair1 = []
        for repeate in range(2):
            self.pair1.append(
                ConvBnReLU(in_channels=1024,
                           out_channels=512,
                           kernel_size=1,
                           stride=1,
                           padding=0,
                           momentum=self.momentum,
                           negative_slope=self.negative_slope))
            self.pair1.append(
                ConvBnReLU(in_channels=512,
                           out_channels=1024,
                           kernel_size=3,
                           stride=1,
                           padding=1,
                           momentum=self.momentum,
                           negative_slope=self.negative_slope))
        self.pair1 = torch.nn.Sequential(*self.pair1)

        self.cbrl7 = ConvBnReLU(in_channels=1024,
                                out_channels=512,
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                momentum=self.momentum,
                                negative_slope=self.negative_slope)
        self.cbrl8 = ConvBnReLU(in_channels=512,
                                out_channels=1024,
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                momentum=self.momentum,
                                negative_slope=self.negative_slope)
        self.conv1 = torch.nn.Conv2d(in_channels=1024,
                                     out_channels=self.detection_channels,
                                     kernel_size=1,
                                     padding=0,
                                     bias=True)

        self.route5 = Route()
        self.conv4 = torch.nn.Conv2d(in_channels=512,
                                     out_channels=self.embedding_channels,
                                     kernel_size=3,
                                     padding=1,
                                     bias=True)
        self.route6 = Route()
        '''YOLO2'''

        self.route1 = Route()
        self.cbrl9 = ConvBnReLU(in_channels=512,
                                out_channels=256,
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                momentum=self.momentum,
                                negative_slope=self.negative_slope)
        self.upsample1 = Upsample(scale_factor=2)
        self.route2 = Route()
        self.cbrl10 = ConvBnReLU(in_channels=768,
                                 out_channels=256,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0,
                                 momentum=self.momentum,
                                 negative_slope=self.negative_slope)

        self.pair2 = []
        for repeate in range(2):
            self.pair2.append(
                ConvBnReLU(in_channels=256,
                           out_channels=512,
                           kernel_size=3,
                           stride=1,
                           padding=1,
                           momentum=self.momentum,
                           negative_slope=self.negative_slope))
            self.pair2.append(
                ConvBnReLU(in_channels=512,
                           out_channels=256,
                           kernel_size=1,
                           stride=1,
                           padding=0,
                           momentum=self.momentum,
                           negative_slope=self.negative_slope))
        self.pair2 = torch.nn.Sequential(*self.pair2)

        self.cbrl11 = ConvBnReLU(in_channels=256,
                                 out_channels=512,
                                 kernel_size=3,
                                 stride=1,
                                 padding=1,
                                 momentum=self.momentum,
                                 negative_slope=self.negative_slope)
        self.conv2 = torch.nn.Conv2d(in_channels=512,
                                     out_channels=self.detection_channels,
                                     kernel_size=1,
                                     padding=0,
                                     bias=True)

        self.route7 = Route()
        self.conv5 = torch.nn.Conv2d(in_channels=256,
                                     out_channels=self.embedding_channels,
                                     kernel_size=3,
                                     padding=1,
                                     bias=True)
        self.route8 = Route()
        '''YOLO3'''

        self.route3 = Route()
        self.cbrl12 = ConvBnReLU(in_channels=256,
                                 out_channels=128,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0,
                                 momentum=self.momentum,
                                 negative_slope=self.negative_slope)
        self.upsample2 = Upsample(scale_factor=2)
        self.route4 = Route()
        self.cbrl13 = ConvBnReLU(in_channels=384,
                                 out_channels=128,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0,
                                 momentum=self.momentum,
                                 negative_slope=self.negative_slope)

        self.pair3 = []
        for repeate in range(2):
            self.pair3.append(
                ConvBnReLU(in_channels=128,
                           out_channels=256,
                           kernel_size=3,
                           stride=1,
                           padding=1,
                           momentum=self.momentum,
                           negative_slope=self.negative_slope))
            self.pair3.append(
                ConvBnReLU(in_channels=256,
                           out_channels=128,
                           kernel_size=1,
                           stride=1,
                           padding=0,
                           momentum=self.momentum,
                           negative_slope=self.negative_slope))
        self.pair3 = torch.nn.Sequential(*self.pair3)

        self.cbrl14 = ConvBnReLU(in_channels=128,
                                 out_channels=256,
                                 kernel_size=3,
                                 stride=1,
                                 padding=1,
                                 momentum=self.momentum,
                                 negative_slope=self.negative_slope)
        self.conv3 = torch.nn.Conv2d(in_channels=256,
                                     out_channels=self.detection_channels,
                                     kernel_size=1,
                                     padding=0,
                                     bias=True)

        self.route9 = Route()
        self.conv6 = torch.nn.Conv2d(in_channels=128,
                                     out_channels=self.embedding_channels,
                                     kernel_size=3,
                                     padding=1,
                                     bias=True)
        self.route10 = Route()
        '''Shared identifiers classifier'''

        self.classifier = torch.nn.Linear(
            self.embedding_channels,
            num_ids) if num_ids > 0 else torch.nn.Sequential()
        self.criterion = yolov3.YOLOv3Loss(num_classes, anchors, num_ids)

        self.__init_weights()
示例#5
0
def warmup_multistep_factor(step, warmup, milestones, gamma):
    """Return the LR multiplier for ``step`` under warmup + multi-step decay.

    While ``step < warmup`` the multiplier ramps up as ``(step / warmup) ** 4``
    (so it is 0 at step 0). Afterwards it is ``gamma`` raised to the number of
    milestones that ``step`` has strictly passed.

    Args:
        step: Scheduler step index (the value ``LambdaLR`` passes in).
        warmup: Number of warmup steps; ``0`` (or less) disables warmup.
        milestones: Iterable of step indices after which the LR is decayed.
        gamma: Multiplicative decay applied once per passed milestone.

    Returns:
        The multiplier applied to the optimizer's base learning rate.
    """
    if step < warmup:
        return pow(step / warmup, 4)
    factor = 1
    for milestone in milestones:
        factor *= pow(gamma, int(step > milestone))
    return factor


def main(args):
    """Train the DarkNet YOLOv3 model configured by the parsed CLI ``args``.

    Reads (among others) ``args.workspace``, ``args.dataset``,
    ``args.in_size``/``args.scale_step`` (comma-separated ints),
    ``args.milestones`` (comma-separated ints), optimizer settings
    (``optim``, ``lr``, ``momentum``, ``weight_decay``), schedule settings
    (``warmup``, ``lr_gamma``, ``epochs``) and checkpoint/resume flags.

    Side effects: after every epoch writes the model weights to
    ``{workspace}/checkpoint/{savename}-ckpt-NNN.pth`` and the optimizer /
    scheduler state to ``{workspace}/checkpoint/trainer-ckpt.pth``.
    """
    try:
        mp.set_start_method('spawn')
    except RuntimeError:
        # The start method can only be set once per process; keep whatever
        # was already configured.
        pass

    utils.make_workspace_dirs(args.workspace)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    in_size = [int(s) for s in args.in_size.split(',')]
    scale_step = [int(ss) for ss in args.scale_step.split(',')]
    anchors = np.loadtxt(os.path.join(args.dataset, 'anchors.txt'))
    scale_sampler = utils.TrainScaleSampler(scale_step, args.rescale_freq)
    # Input size is kept in a shared-memory tensor that both the collate_fn
    # and train_one_epoch receive; presumably train_one_epoch rewrites it
    # (via scale_sampler) for multi-scale training — confirm.
    shared_size = torch.IntTensor(in_size).share_memory_()

    dataset = ds.CustomDataset(args.dataset, 'train')
    collate_fn = partial(ds.collate_fn, in_size=shared_size, train=True)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              args.batch_size,
                                              shuffle=True,
                                              num_workers=args.workers,
                                              collate_fn=collate_fn,
                                              pin_memory=args.pin)

    model = darknet.DarkNet(anchors, in_size,
                            num_classes=args.num_classes).to(device)
    if args.checkpoint:
        print(f'load {args.checkpoint}')
        # map_location lets GPU-saved weights load on a CPU-only host.
        model.load_state_dict(torch.load(args.checkpoint,
                                         map_location=device))
    if args.sparsity:
        model.load_prune_permit('model/prune_permit.json')

    criterion = yolov3.YOLOv3Loss(args.num_classes, anchors)
    # NOTE(review): only useful for evaluation, which is a no-op in the loop
    # below.
    decoder = yolov3.YOLOv3EvalDecoder(in_size, args.num_classes, anchors)

    params = [p for p in model.parameters() if p.requires_grad]
    if args.optim == 'sgd':
        optimizer = torch.optim.SGD(params,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.Adam(params,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)

    trainer = f'{args.workspace}/checkpoint/trainer-ckpt.pth'
    trainer_state = None
    if args.resume:
        trainer_state = torch.load(trainer, map_location=device)
        optimizer.load_state_dict(trainer_state['optimizer'])

    milestones = [int(ms) for ms in args.milestones.split(',')]

    # A plain nested function (not a partial/callable object) so LambdaLR's
    # state_dict keeps the same layout as before.
    def lr_lambda(step):
        return warmup_multistep_factor(step, args.warmup, milestones,
                                       args.lr_gamma)

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    if args.resume:
        start_epoch = trainer_state['epoch'] + 1
        lr_scheduler.load_state_dict(trainer_state['lr_scheduler'])
    else:
        start_epoch = 0
    print(f'Start training from epoch {start_epoch}')

    for epoch in range(start_epoch, args.epochs):
        msgs = train_one_epoch(model, criterion, optimizer, lr_scheduler,
                               data_loader, epoch, args.interval, shared_size,
                               scale_sampler, device, args.sparsity, args.lamb)
        utils.print_training_message(args.workspace, epoch + 1, msgs,
                                     args.batch_size)
        # Format the epoch inside the f-string: a second %-formatting pass
        # would break on any '%' in the workspace path or savename.
        torch.save(
            model.state_dict(),
            f'{args.workspace}/checkpoint/{args.savename}-ckpt-{epoch:03d}.pth')
        torch.save(
            {
                'epoch': epoch,
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict()
            }, trainer)

        if epoch >= args.eval_epoch:
            # TODO: evaluation hook is currently a no-op.
            pass